From 4381ad503ca3e83510b876281fc768d00d40d499 Mon Sep 17 00:00:00 2001 From: Christopher Wilcox Date: Thu, 22 Jul 2021 10:14:32 -0700 Subject: [PATCH 01/19] fix: move to using insecure grpc channels with emulator (#402) * fix: move to using insecure grpc channels with emulator * chore: format * fix: add code to manually inject the id token on an insecure channel * chore: add line for comment * test: use the correct credentials object in mock * chore: black * chore: unused var * always configure the bearer token, even if not available * test: test the path populating an id token * chore: remove unused code and testing of unused code * chore: remove some code repetition * chore: feedback --- google/cloud/firestore_v1/base_client.py | 54 +++++++----------------- tests/unit/v1/test_base_client.py | 38 +++++++++-------- 2 files changed, 35 insertions(+), 57 deletions(-) diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index b2af21e3f6..7eb5c26b08 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -167,50 +167,26 @@ def _firestore_api_helper(self, transport, client_class, client_module) -> Any: def _emulator_channel(self, transport): """ - Creates a channel using self._credentials in a similar way to grpc.secure_channel but - using grpc.local_channel_credentials() rather than grpc.ssh_channel_credentials() to allow easy connection - to a local firestore emulator. This allows local testing of firestore rules if the credentials have been - created from a signed custom token. + Creates an insecure channel to communicate with the local emulator. + If credentials are provided the token is extracted and added to the + headers. This supports local testing of firestore rules if the credentials + have been created from a signed custom token. :return: grpc.Channel or grpc.aio.Channel """ - # TODO: Implement a special credentials type for emulator and use - # "transport.create_channel" to create gRPC channels once google-auth - # extends it's allowed credentials types. + # Insecure channels are used for the emulator as secure channels + # cannot be used to communicate on some environments. + # https://github.com/googleapis/python-firestore/issues/359 + # Default the token to a non-empty string, in this case "owner". + token = "owner" + if self._credentials is not None and self._credentials.id_token is not None: + token = self._credentials.id_token + options = [("Authorization", f"Bearer {token}")] + if "GrpcAsyncIOTransport" in str(transport.__name__): - return grpc.aio.secure_channel( - self._emulator_host, self._local_composite_credentials() - ) + return grpc.aio.insecure_channel(self._emulator_host, options=options) else: - return grpc.secure_channel( - self._emulator_host, self._local_composite_credentials() - ) - - def _local_composite_credentials(self): - """ - Creates the credentials for the local emulator channel - :return: grpc.ChannelCredentials - """ - credentials = google.auth.credentials.with_scopes_if_required( - self._credentials, None - ) - request = google.auth.transport.requests.Request() - - # Create the metadata plugin for inserting the authorization header. - metadata_plugin = google.auth.transport.grpc.AuthMetadataPlugin( - credentials, request - ) - - # Create a set of grpc.CallCredentials using the metadata plugin. 
- google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin) - - # Using the local_credentials to allow connection to emulator - local_credentials = grpc.local_channel_credentials() - - # Combine the local credentials and the authorization credentials. - return grpc.composite_channel_credentials( - local_credentials, google_auth_credentials - ) + return grpc.insecure_channel(self._emulator_host, options=options) def _target_helper(self, client_class) -> str: """Return the target (where the API is). diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py index fd176d7603..5de0e4962a 100644 --- a/tests/unit/v1/test_base_client.py +++ b/tests/unit/v1/test_base_client.py @@ -146,11 +146,11 @@ def test_emulator_channel(self): ) emulator_host = "localhost:8081" + credentials = _make_credentials() + database = "quanta" with mock.patch("os.getenv") as getenv: getenv.return_value = emulator_host - - credentials = _make_credentials() - database = "quanta" + credentials.id_token = None client = self._make_one( project=self.PROJECT, credentials=credentials, database=database ) @@ -160,21 +160,23 @@ def test_emulator_channel(self): self.assertTrue(isinstance(channel, grpc.Channel)) channel = client._emulator_channel(FirestoreGrpcAsyncIOTransport) self.assertTrue(isinstance(channel, grpc.aio.Channel)) - # checks that the credentials are composite ones using a local channel from grpc - composite_credentials = client._local_composite_credentials() - self.assertTrue(isinstance(composite_credentials, grpc.ChannelCredentials)) - self.assertTrue( - isinstance( - composite_credentials._credentials._call_credentialses[0], - grpc._cython.cygrpc.MetadataPluginCallCredentials, + + # Verify that when credentials are provided with an id token it is used + # for channel construction + # NOTE: On windows, emulation requires an insecure channel. If this is + # altered to use a secure channel, start by verifying that it still + # works as expected on windows. 
+ with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + credentials.id_token = "test" + client = self._make_one( + project=self.PROJECT, credentials=credentials, database=database ) - ) - self.assertTrue( - isinstance( - composite_credentials._credentials._channel_credentials, - grpc._cython.cygrpc.LocalChannelCredentials, + with mock.patch("grpc.insecure_channel") as insecure_channel: + channel = client._emulator_channel(FirestoreGrpcTransport) + insecure_channel.assert_called_once_with( + emulator_host, options=[("Authorization", "Bearer test")] ) - ) def test_field_path(self): klass = self._get_target_class() @@ -392,9 +394,9 @@ def test_paths(self): def _make_credentials(): - import google.auth.credentials + import google.oauth2.credentials - return mock.Mock(spec=google.auth.credentials.Credentials) + return mock.Mock(spec=google.oauth2.credentials.Credentials) def _make_batch_response(**kwargs): From 7e0f8401794b7c45d5a49cc5b274d180ac1dfb75 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Fri, 23 Jul 2021 15:32:46 +0000 Subject: [PATCH 02/19] chore: fix kokoro config for samples (#404) Source-Link: https://github.com/googleapis/synthtool/commit/dd05f9d12f134871c9e45282349c9856fbebecdd Post-Processor: gcr.io/repo-automation-bots/owlbot-python:latest@sha256:aea14a583128771ae8aefa364e1652f3c56070168ef31beb203534222d842b8b --- .github/.OwlBot.lock.yaml | 2 +- .kokoro/samples/python3.6/periodic-head.cfg | 2 +- .kokoro/samples/python3.7/periodic-head.cfg | 2 +- .kokoro/samples/python3.8/periodic-head.cfg | 2 +- .kokoro/samples/python3.9/periodic-head.cfg | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml index d57f742046..9ee60f7e48 100644 --- a/.github/.OwlBot.lock.yaml +++ b/.github/.OwlBot.lock.yaml @@ -1,3 +1,3 @@ docker: image: gcr.io/repo-automation-bots/owlbot-python:latest - digest: sha256:6186535cbdbf6b9fe61f00294929221d060634dae4a0795c1cefdbc995b2d605 + digest: sha256:aea14a583128771ae8aefa364e1652f3c56070168ef31beb203534222d842b8b diff --git a/.kokoro/samples/python3.6/periodic-head.cfg b/.kokoro/samples/python3.6/periodic-head.cfg index f9cfcd33e0..21998d0902 100644 --- a/.kokoro/samples/python3.6/periodic-head.cfg +++ b/.kokoro/samples/python3.6/periodic-head.cfg @@ -7,5 +7,5 @@ env_vars: { env_vars: { key: "TRAMPOLINE_BUILD_FILE" - value: "github/python-pubsub/.kokoro/test-samples-against-head.sh" + value: "github/python-firestore/.kokoro/test-samples-against-head.sh" } diff --git a/.kokoro/samples/python3.7/periodic-head.cfg b/.kokoro/samples/python3.7/periodic-head.cfg index f9cfcd33e0..21998d0902 100644 --- a/.kokoro/samples/python3.7/periodic-head.cfg +++ b/.kokoro/samples/python3.7/periodic-head.cfg @@ -7,5 +7,5 @@ env_vars: { env_vars: { key: "TRAMPOLINE_BUILD_FILE" - value: "github/python-pubsub/.kokoro/test-samples-against-head.sh" + value: "github/python-firestore/.kokoro/test-samples-against-head.sh" } diff --git a/.kokoro/samples/python3.8/periodic-head.cfg b/.kokoro/samples/python3.8/periodic-head.cfg index f9cfcd33e0..21998d0902 100644 --- a/.kokoro/samples/python3.8/periodic-head.cfg +++ b/.kokoro/samples/python3.8/periodic-head.cfg @@ -7,5 +7,5 @@ env_vars: { env_vars: { key: "TRAMPOLINE_BUILD_FILE" - value: "github/python-pubsub/.kokoro/test-samples-against-head.sh" + value: "github/python-firestore/.kokoro/test-samples-against-head.sh" } diff --git a/.kokoro/samples/python3.9/periodic-head.cfg 
b/.kokoro/samples/python3.9/periodic-head.cfg index f9cfcd33e0..21998d0902 100644 --- a/.kokoro/samples/python3.9/periodic-head.cfg +++ b/.kokoro/samples/python3.9/periodic-head.cfg @@ -7,5 +7,5 @@ env_vars: { env_vars: { key: "TRAMPOLINE_BUILD_FILE" - value: "github/python-pubsub/.kokoro/test-samples-against-head.sh" + value: "github/python-firestore/.kokoro/test-samples-against-head.sh" } From 8703b48c45e7bb742a794cad9597740c44182f81 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Sat, 24 Jul 2021 10:16:22 +0000 Subject: [PATCH 03/19] fix: enable self signed jwt for grpc (#405) PiperOrigin-RevId: 386504689 Source-Link: https://github.com/googleapis/googleapis/commit/762094a99ac6e03a17516b13dfbef37927267a70 Source-Link: https://github.com/googleapis/googleapis-gen/commit/6bfc480e1a161d5de121c2bcc3745885d33b265a --- .../services/firestore_admin/client.py | 4 +++ .../firestore_v1/services/firestore/client.py | 4 +++ .../test_firestore_admin.py | 31 +++++++++++-------- .../unit/gapic/firestore_v1/test_firestore.py | 29 ++++++++++------- 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py index 490b9465ea..7f34c8e30a 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py @@ -399,6 +399,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_index( diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index 126723d505..1a74fc874a 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -351,6 +351,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def get_document( diff --git a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py index 6bcb9d73a0..d16690ce3d 100644 --- a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py +++ b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py @@ -132,18 +132,6 @@ def test_firestore_admin_client_from_service_account_info(client_class): assert client.transport._host == "firestore.googleapis.com:443" -@pytest.mark.parametrize( - "client_class", [FirestoreAdminClient, FirestoreAdminAsyncClient,] -) -def test_firestore_admin_client_service_account_always_use_jwt(client_class): - with mock.patch.object( - service_account.Credentials, "with_always_use_jwt_access", create=True - ) as use_jwt: - creds = service_account.Credentials(None, None, None) - client = client_class(credentials=creds) - use_jwt.assert_not_called() - - @pytest.mark.parametrize( "transport_class,transport_name", [ @@ -151,7 +139,7 @@ def test_firestore_admin_client_service_account_always_use_jwt(client_class): (transports.FirestoreAdminGrpcAsyncIOTransport, "grpc_asyncio"), 
], ) -def test_firestore_admin_client_service_account_always_use_jwt_true( +def test_firestore_admin_client_service_account_always_use_jwt( transport_class, transport_name ): with mock.patch.object( @@ -161,6 +149,13 @@ def test_firestore_admin_client_service_account_always_use_jwt_true( transport = transport_class(credentials=creds, always_use_jwt_access=True) use_jwt.assert_called_once_with(True) + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + @pytest.mark.parametrize( "client_class", [FirestoreAdminClient, FirestoreAdminAsyncClient,] @@ -241,6 +236,7 @@ def test_firestore_admin_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -257,6 +253,7 @@ def test_firestore_admin_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -273,6 +270,7 @@ def test_firestore_admin_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -301,6 +299,7 @@ def test_firestore_admin_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -367,6 +366,7 @@ def test_firestore_admin_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. Whether client cert is used depends on @@ -400,6 +400,7 @@ def test_firestore_admin_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. 
@@ -421,6 +422,7 @@ def test_firestore_admin_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -451,6 +453,7 @@ def test_firestore_admin_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -481,6 +484,7 @@ def test_firestore_admin_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -500,6 +504,7 @@ def test_firestore_admin_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index 3220d06720..de0f39a82a 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -121,16 +121,6 @@ def test_firestore_client_from_service_account_info(client_class): assert client.transport._host == "firestore.googleapis.com:443" -@pytest.mark.parametrize("client_class", [FirestoreClient, FirestoreAsyncClient,]) -def test_firestore_client_service_account_always_use_jwt(client_class): - with mock.patch.object( - service_account.Credentials, "with_always_use_jwt_access", create=True - ) as use_jwt: - creds = service_account.Credentials(None, None, None) - client = client_class(credentials=creds) - use_jwt.assert_not_called() - - @pytest.mark.parametrize( "transport_class,transport_name", [ @@ -138,7 +128,7 @@ def test_firestore_client_service_account_always_use_jwt(client_class): (transports.FirestoreGrpcAsyncIOTransport, "grpc_asyncio"), ], ) -def test_firestore_client_service_account_always_use_jwt_true( +def test_firestore_client_service_account_always_use_jwt( transport_class, transport_name ): with mock.patch.object( @@ -148,6 +138,13 @@ def test_firestore_client_service_account_always_use_jwt_true( transport = transport_class(credentials=creds, always_use_jwt_access=True) use_jwt.assert_called_once_with(True) + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + @pytest.mark.parametrize("client_class", [FirestoreClient, FirestoreAsyncClient,]) def test_firestore_client_from_service_account_file(client_class): @@ -222,6 +219,7 @@ def test_firestore_client_client_options(client_class, transport_class, transpor client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -238,6 +236,7 @@ def test_firestore_client_client_options(client_class, transport_class, transpor client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -254,6 +253,7 @@ def test_firestore_client_client_options(client_class, transport_class, transpor client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + 
always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -282,6 +282,7 @@ def test_firestore_client_client_options(client_class, transport_class, transpor client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -346,6 +347,7 @@ def test_firestore_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. Whether client cert is used depends on @@ -379,6 +381,7 @@ def test_firestore_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -400,6 +403,7 @@ def test_firestore_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -430,6 +434,7 @@ def test_firestore_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -460,6 +465,7 @@ def test_firestore_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -477,6 +483,7 @@ def test_firestore_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) From 509648a9f0f9375a681923d1e38b2998631aba9d Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Tue, 27 Jul 2021 14:18:19 -0700 Subject: [PATCH 04/19] refactor: added BaseQuery._copy method (#406) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: added BaseQuery.copy method * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md * responded to code review * migrated last copy location * moved _not_passed check to identity instead of equality Co-authored-by: Owl Bot --- google/cloud/firestore_v1/base_query.py | 109 +++++++++--------------- 1 file changed, 39 insertions(+), 70 deletions(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index aafdab979c..5d11ccb3c0 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -85,6 +85,8 @@ ) _MISMATCH_CURSOR_W_ORDER_BY = "The cursor {!r} does not match the order fields {!r}." +_not_passed = object() + class BaseQuery(object): """Represents a query to the Firestore API. 
@@ -231,19 +233,41 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery": for field_path in field_paths ] ) + return self._copy(projection=new_projection) + + def _copy( + self, + *, + projection: Optional[query.StructuredQuery.Projection] = _not_passed, + field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed, + orders: Optional[Tuple[query.StructuredQuery.Order]] = _not_passed, + limit: Optional[int] = _not_passed, + limit_to_last: Optional[bool] = _not_passed, + offset: Optional[int] = _not_passed, + start_at: Optional[Tuple[dict, bool]] = _not_passed, + end_at: Optional[Tuple[dict, bool]] = _not_passed, + all_descendants: Optional[bool] = _not_passed, + ) -> "BaseQuery": return self.__class__( self._parent, - projection=new_projection, - field_filters=self._field_filters, - orders=self._orders, - limit=self._limit, - limit_to_last=self._limit_to_last, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, + projection=self._evaluate_param(projection, self._projection), + field_filters=self._evaluate_param(field_filters, self._field_filters), + orders=self._evaluate_param(orders, self._orders), + limit=self._evaluate_param(limit, self._limit), + limit_to_last=self._evaluate_param(limit_to_last, self._limit_to_last), + offset=self._evaluate_param(offset, self._offset), + start_at=self._evaluate_param(start_at, self._start_at), + end_at=self._evaluate_param(end_at, self._end_at), + all_descendants=self._evaluate_param( + all_descendants, self._all_descendants + ), ) + def _evaluate_param(self, value, fallback_value): + """Helper which allows `None` to be passed into `copy` and be set on the + copy instead of being misinterpreted as an unpassed parameter.""" + return value if value is not _not_passed else fallback_value + def where(self, field_path: str, op_string: str, value) -> "BaseQuery": """Filter the query on a field. @@ -301,18 +325,7 @@ def where(self, field_path: str, op_string: str, value) -> "BaseQuery": ) new_filters = self._field_filters + (filter_pb,) - return self.__class__( - self._parent, - projection=self._projection, - field_filters=new_filters, - orders=self._orders, - limit=self._limit, - offset=self._offset, - limit_to_last=self._limit_to_last, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) + return self._copy(field_filters=new_filters) @staticmethod def _make_order(field_path, direction) -> StructuredQuery.Order: @@ -354,18 +367,7 @@ def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery": order_pb = self._make_order(field_path, direction) new_orders = self._orders + (order_pb,) - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=new_orders, - limit=self._limit, - limit_to_last=self._limit_to_last, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) + return self._copy(orders=new_orders) def limit(self, count: int) -> "BaseQuery": """Limit a query to return at most `count` matching results. @@ -384,18 +386,7 @@ def limit(self, count: int) -> "BaseQuery": A limited query. Acts as a copy of the current query, modified with the newly added "limit" filter. 
""" - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=self._orders, - limit=count, - limit_to_last=False, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) + return self._copy(limit=count, limit_to_last=False) def limit_to_last(self, count: int) -> "BaseQuery": """Limit a query to return the last `count` matching results. @@ -414,18 +405,7 @@ def limit_to_last(self, count: int) -> "BaseQuery": A limited query. Acts as a copy of the current query, modified with the newly added "limit" filter. """ - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=self._orders, - limit=count, - limit_to_last=True, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) + return self._copy(limit=count, limit_to_last=True) def offset(self, num_to_skip: int) -> "BaseQuery": """Skip to an offset in a query. @@ -442,18 +422,7 @@ def offset(self, num_to_skip: int) -> "BaseQuery": An offset query. Acts as a copy of the current query, modified with the newly added "offset" field. """ - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=self._orders, - limit=self._limit, - limit_to_last=self._limit_to_last, - offset=num_to_skip, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) + return self._copy(offset=num_to_skip) def _check_snapshot(self, document_snapshot) -> None: """Validate local snapshots for non-collection-group queries. @@ -523,7 +492,7 @@ def _cursor_helper( query_kwargs["start_at"] = self._start_at query_kwargs["end_at"] = cursor_pair - return self.__class__(self._parent, **query_kwargs) + return self._copy(**query_kwargs) def start_at( self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] From 9b905cbbfc5321b2776a82d55f0154fa117abab5 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Tue, 27 Jul 2021 14:21:50 -0700 Subject: [PATCH 05/19] refactor: added BaseQuery._copy method (#406) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: added BaseQuery.copy method * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md * responded to code review * migrated last copy location * moved _not_passed check to identity instead of equality Co-authored-by: Owl Bot From ae4148e9516c3512441ad7ab7ab9df0699b81399 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Thu, 5 Aug 2021 11:15:52 -0400 Subject: [PATCH 06/19] tests: split systests out to separate Kokoro job (#412) Closes #411. --- .kokoro/presubmit/presubmit.cfg | 8 +++++++- .kokoro/presubmit/system-3.7.cfg | 7 +++++++ owlbot.py | 1 + 3 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 .kokoro/presubmit/system-3.7.cfg diff --git a/.kokoro/presubmit/presubmit.cfg b/.kokoro/presubmit/presubmit.cfg index 8f43917d92..b158096f0a 100644 --- a/.kokoro/presubmit/presubmit.cfg +++ b/.kokoro/presubmit/presubmit.cfg @@ -1 +1,7 @@ -# Format: //devtools/kokoro/config/proto/build.proto \ No newline at end of file +# Format: //devtools/kokoro/config/proto/build.proto + +# Disable system tests. 
+env_vars: { + key: "RUN_SYSTEM_TESTS" + value: "false" +} diff --git a/.kokoro/presubmit/system-3.7.cfg b/.kokoro/presubmit/system-3.7.cfg new file mode 100644 index 0000000000..461537b3fb --- /dev/null +++ b/.kokoro/presubmit/system-3.7.cfg @@ -0,0 +1,7 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +# Only run this nox session. +env_vars: { + key: "NOX_SESSION" + value: "system-3.7" +} \ No newline at end of file diff --git a/owlbot.py b/owlbot.py index f4cf08e0a8..10f5894422 100644 --- a/owlbot.py +++ b/owlbot.py @@ -136,6 +136,7 @@ def update_fixup_scripts(library): system_test_external_dependencies=["pytest-asyncio"], microgenerator=True, cov_level=100, + split_system_tests=True, ) s.move(templated_files) From 1adfc81237c4ddee665e81f1beaef808cddb860e Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Mon, 9 Aug 2021 11:26:57 -0400 Subject: [PATCH 07/19] docs: fixed broken links to devsite (#417) --- google/cloud/firestore_v1/transforms.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/google/cloud/firestore_v1/transforms.py b/google/cloud/firestore_v1/transforms.py index e9aa876063..f1361c951f 100644 --- a/google/cloud/firestore_v1/transforms.py +++ b/google/cloud/firestore_v1/transforms.py @@ -72,7 +72,7 @@ class ArrayUnion(_ValueList): """Field transform: appends missing values to an array field. See: - https://cloud.google.com/firestore/docs/reference/rpc/google.cloud.firestore.v1#google.cloud.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.cloud.firestore.v1.ArrayValue.google.cloud.firestore.v1.DocumentTransform.FieldTransform.append_missing_elements + https://cloud.google.com/firestore/docs/reference/rpc/google.firestore.v1#google.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.firestore.v1.ArrayValue.google.firestore.v1.DocumentTransform.FieldTransform.append_missing_elements Args: values (List | Tuple): values to append. @@ -83,7 +83,7 @@ class ArrayRemove(_ValueList): """Field transform: remove values from an array field. See: - https://cloud.google.com/firestore/docs/reference/rpc/google.cloud.firestore.v1#google.cloud.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.cloud.firestore.v1.ArrayValue.google.cloud.firestore.v1.DocumentTransform.FieldTransform.remove_all_from_array + https://cloud.google.com/firestore/docs/reference/rpc/google.firestore.v1#google.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.firestore.v1.ArrayValue.google.firestore.v1.DocumentTransform.FieldTransform.remove_all_from_array Args: values (List | Tuple): values to remove. @@ -122,7 +122,7 @@ class Increment(_NumericValue): """Field transform: increment a numeric field with specified value. See: - https://cloud.google.com/firestore/docs/reference/rpc/google.cloud.firestore.v1#google.cloud.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.cloud.firestore.v1.ArrayValue.google.cloud.firestore.v1.DocumentTransform.FieldTransform.increment + https://cloud.google.com/firestore/docs/reference/rpc/google.firestore.v1#google.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.firestore.v1.ArrayValue.google.firestore.v1.DocumentTransform.FieldTransform.increment Args: value (int | float): value used to increment the field. @@ -133,7 +133,7 @@ class Maximum(_NumericValue): """Field transform: bound numeric field with specified value. 
See: - https://cloud.google.com/firestore/docs/reference/rpc/google.cloud.firestore.v1#google.cloud.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.cloud.firestore.v1.ArrayValue.google.cloud.firestore.v1.DocumentTransform.FieldTransform.maximum + https://cloud.google.com/firestore/docs/reference/rpc/google.firestore.v1#google.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.firestore.v1.ArrayValue.google.firestore.v1.DocumentTransform.FieldTransform.maximum Args: value (int | float): value used to bound the field. @@ -144,7 +144,7 @@ class Minimum(_NumericValue): """Field transform: bound numeric field with specified value. See: - https://cloud.google.com/firestore/docs/reference/rpc/google.cloud.firestore.v1#google.cloud.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.cloud.firestore.v1.ArrayValue.google.cloud.firestore.v1.DocumentTransform.FieldTransform.minimum + https://cloud.google.com/firestore/docs/reference/rpc/google.firestore.v1#google.firestore.v1.DocumentTransform.FieldTransform.FIELDS.google.firestore.v1.ArrayValue.google.firestore.v1.DocumentTransform.FieldTransform.minimum Args: value (int | float): value used to bound the field. From 0176cc7fef8752433b5c2496046d3a56557eb824 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Mon, 9 Aug 2021 13:26:15 -0400 Subject: [PATCH 08/19] docs: added generated docs for Bundles (#416) * docs: added generated docs for Bundles * removed whitespace Co-authored-by: Tres Seaver --- docs/bundles.rst | 6 ++++++ docs/index.rst | 1 + google/cloud/firestore_bundle/bundle.py | 16 ++++++++++++---- 3 files changed, 19 insertions(+), 4 deletions(-) create mode 100644 docs/bundles.rst diff --git a/docs/bundles.rst b/docs/bundles.rst new file mode 100644 index 0000000000..92724a3b6b --- /dev/null +++ b/docs/bundles.rst @@ -0,0 +1,6 @@ +Bundles +~~~~~~~ + +.. automodule:: google.cloud.firestore_bundle.bundle + :members: + :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index 34002786f1..3fce768ab7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -14,6 +14,7 @@ API Reference field_path query batch + bundles transaction transforms types diff --git a/google/cloud/firestore_bundle/bundle.py b/google/cloud/firestore_bundle/bundle.py index eae1fa3f4a..73a53aadb5 100644 --- a/google/cloud/firestore_bundle/bundle.py +++ b/google/cloud/firestore_bundle/bundle.py @@ -51,20 +51,22 @@ class FirestoreBundle: Usage: - from google.cloud.firestore import Client + .. code-block:: python + + from google.cloud.firestore import Client, _helpers from google.cloud.firestore_bundle import FirestoreBundle - from google.cloud.firestore import _helpers db = Client() bundle = FirestoreBundle('my-bundle') bundle.add_named_query('all-users', db.collection('users')._query()) bundle.add_named_query( 'top-ten-hamburgers', - db.collection('hamburgers').limit(limit=10)._query(), + db.collection('hamburgers').limit(limit=10), ) serialized: str = bundle.build() - # Store somewhere like your GCS for retrieval by a client SDK. + # Store somewhere like a Google Cloud Storage bucket for retrieval by + # a client SDK. Args: name (str): The Id of the bundle. @@ -88,6 +90,8 @@ def add_document(self, snapshot: DocumentSnapshot) -> "FirestoreBundle": Example: + .. code-block:: python + from google.cloud import firestore db = firestore.Client() @@ -142,6 +146,8 @@ def add_named_query(self, name: str, query: BaseQuery) -> "FirestoreBundle": Example: + .. 
code-block:: python + from google.cloud import firestore db = firestore.Client() @@ -293,6 +299,8 @@ def build(self) -> str: Example: + .. code-block:: python + from google.cloud import firestore db = firestore.Client() From eb45a36e6c06b642106e061a32bfc119eb7e5bf0 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Wed, 11 Aug 2021 09:02:12 -0400 Subject: [PATCH 09/19] feat: add support for recursive queries (#407) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: added BaseQuery.copy method * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md * responded to code review * feat: added recursive query * tidied up * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md * more tidying up * fixed error with path compilation * fixed async handling in system tests * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md * Update google/cloud/firestore_v1/base_collection.py Co-authored-by: Christopher Wilcox * reverted error message changes * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * comment updates Co-authored-by: Owl Bot Co-authored-by: Christopher Wilcox --- google/cloud/firestore_v1/_helpers.py | 1 + google/cloud/firestore_v1/async_query.py | 18 ++- google/cloud/firestore_v1/base_collection.py | 8 +- google/cloud/firestore_v1/base_query.py | 65 +++++++++- google/cloud/firestore_v1/query.py | 15 ++- tests/system/test_system.py | 121 ++++++++++++++++++ tests/system/test_system_async.py | 125 +++++++++++++++++++ tests/unit/v1/test_async_collection.py | 6 + tests/unit/v1/test_base_query.py | 6 + tests/unit/v1/test_collection.py | 6 + 10 files changed, 367 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index aebdbee477..52d88006cb 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -144,6 +144,7 @@ def verify_path(path, is_collection) -> None: if is_collection: if num_elements % 2 == 0: raise ValueError("A collection must have an odd number of path elements") + else: if num_elements % 2 == 1: raise ValueError("A document must have an even number of path elements") diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index f772194e85..2f94b5f7c9 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -22,6 +22,7 @@ from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore +from google.cloud import firestore_v1 from google.cloud.firestore_v1.base_query import ( BaseCollectionGroup, BaseQuery, @@ -32,7 +33,7 @@ ) from google.cloud.firestore_v1 import async_document -from typing import AsyncGenerator +from typing import AsyncGenerator, Type # Types needed only for Type Hints from google.cloud.firestore_v1.transaction import Transaction @@ -92,6 +93,9 @@ class AsyncQuery(BaseQuery): When false, selects only collections that are immediate children of the `parent` specified in the containing `RunQueryRequest`. When true, selects all descendant collections. + recursive (Optional[bool]): + When true, returns all documents and all documents in any subcollections + below them. Defaults to false. 
""" def __init__( @@ -106,6 +110,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, + recursive=False, ) -> None: super(AsyncQuery, self).__init__( parent=parent, @@ -118,6 +123,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) async def get( @@ -224,6 +230,14 @@ async def stream( if snapshot is not None: yield snapshot + @staticmethod + def _get_collection_reference_class() -> Type[ + "firestore_v1.async_collection.AsyncCollectionReference" + ]: + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + return AsyncCollectionReference + class AsyncCollectionGroup(AsyncQuery, BaseCollectionGroup): """Represents a Collection Group in the Firestore API. @@ -249,6 +263,7 @@ def __init__( start_at=None, end_at=None, all_descendants=True, + recursive=False, ) -> None: super(AsyncCollectionGroup, self).__init__( parent=parent, @@ -261,6 +276,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) @staticmethod diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index ce31bfb0a3..02363efc2e 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -124,7 +124,10 @@ def document(self, document_id: str = None) -> DocumentReference: if document_id is None: document_id = _auto_id() - child_path = self._path + (document_id,) + # Append `self._path` and the passed document's ID as long as the first + # element in the path is not an empty string, which comes from setting the + # parent to "" for recursive queries. + child_path = self._path + (document_id,) if self._path[0] else (document_id,) return self._client.document(*child_path) def _parent_info(self) -> Tuple[Any, str]: @@ -200,6 +203,9 @@ def list_documents( ]: raise NotImplementedError + def recursive(self) -> "BaseQuery": + return self._query().recursive() + def select(self, field_paths: Iterable[str]) -> BaseQuery: """Create a "select" query with this collection as parent. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 5d11ccb3c0..1812cfca00 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -33,7 +33,17 @@ from google.cloud.firestore_v1.types import Cursor from google.cloud.firestore_v1.types import RunQueryResponse from google.cloud.firestore_v1.order import Order -from typing import Any, Dict, Generator, Iterable, NoReturn, Optional, Tuple, Union +from typing import ( + Any, + Dict, + Generator, + Iterable, + NoReturn, + Optional, + Tuple, + Type, + Union, +) # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot @@ -144,6 +154,9 @@ class BaseQuery(object): When false, selects only collections that are immediate children of the `parent` specified in the containing `RunQueryRequest`. When true, selects all descendant collections. + recursive (Optional[bool]): + When true, returns all documents and all documents in any subcollections + below them. Defaults to false. 
""" ASCENDING = "ASCENDING" @@ -163,6 +176,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, + recursive=False, ) -> None: self._parent = parent self._projection = projection @@ -174,6 +188,7 @@ def __init__( self._start_at = start_at self._end_at = end_at self._all_descendants = all_descendants + self._recursive = recursive def __eq__(self, other): if not isinstance(other, self.__class__): @@ -247,6 +262,7 @@ def _copy( start_at: Optional[Tuple[dict, bool]] = _not_passed, end_at: Optional[Tuple[dict, bool]] = _not_passed, all_descendants: Optional[bool] = _not_passed, + recursive: Optional[bool] = _not_passed, ) -> "BaseQuery": return self.__class__( self._parent, @@ -261,6 +277,7 @@ def _copy( all_descendants=self._evaluate_param( all_descendants, self._all_descendants ), + recursive=self._evaluate_param(recursive, self._recursive), ) def _evaluate_param(self, value, fallback_value): @@ -813,6 +830,46 @@ def stream( def on_snapshot(self, callback) -> NoReturn: raise NotImplementedError + def recursive(self) -> "BaseQuery": + """Returns a copy of this query whose iterator will yield all matching + documents as well as each of their descendent subcollections and documents. + + This differs from the `all_descendents` flag, which only returns descendents + whose subcollection names match the parent collection's name. To return + all descendents, regardless of their subcollection name, use this. + """ + copied = self._copy(recursive=True, all_descendants=True) + if copied._parent and copied._parent.id: + original_collection_id = "/".join(copied._parent._path) + + # Reset the parent to nothing so we can recurse through the entire + # database. This is required to have + # `CollectionSelector.collection_id` not override + # `CollectionSelector.all_descendants`, which happens if both are + # set. + copied._parent = copied._get_collection_reference_class()("") + copied._parent._client = self._parent._client + + # But wait! We don't want to load the entire database; only the + # collection the user originally specified. To accomplish that, we + # add the following arcane filters. + + REFERENCE_NAME_MIN_ID = "__id-9223372036854775808__" + start_at = f"{original_collection_id}/{REFERENCE_NAME_MIN_ID}" + + # The backend interprets this null character is flipping the filter + # to mean the end of the range instead of the beginning. + nullChar = "\0" + end_at = f"{original_collection_id}{nullChar}/{REFERENCE_NAME_MIN_ID}" + + copied = ( + copied.order_by(field_path_module.FieldPath.document_id()) + .start_at({field_path_module.FieldPath.document_id(): start_at}) + .end_at({field_path_module.FieldPath.document_id(): end_at}) + ) + + return copied + def _comparator(self, doc1, doc2) -> int: _orders = self._orders @@ -1073,6 +1130,7 @@ def __init__( start_at=None, end_at=None, all_descendants=True, + recursive=False, ) -> None: if not all_descendants: raise ValueError("all_descendants must be True for collection group query.") @@ -1088,6 +1146,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) def _validate_partition_query(self): @@ -1133,6 +1192,10 @@ def get_partitions( ) -> NoReturn: raise NotImplementedError + @staticmethod + def _get_collection_reference_class() -> Type["BaseCollectionGroup"]: + raise NotImplementedError + class QueryPartition: """Represents a bounded partition of a collection group query. 
diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index aa2f5ad096..f1e044cbd1 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -19,6 +19,7 @@ a more common way to create a query than direct usage of the constructor. """ +from google.cloud import firestore_v1 from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -34,7 +35,7 @@ from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List +from typing import Any, Callable, Generator, List, Type class Query(BaseQuery): @@ -105,6 +106,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, + recursive=False, ) -> None: super(Query, self).__init__( parent=parent, @@ -117,6 +119,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) def get( @@ -254,6 +257,14 @@ def on_snapshot(docs, changes, read_time): self, callback, document.DocumentSnapshot, document.DocumentReference ) + @staticmethod + def _get_collection_reference_class() -> Type[ + "firestore_v1.collection.CollectionReference" + ]: + from google.cloud.firestore_v1.collection import CollectionReference + + return CollectionReference + class CollectionGroup(Query, BaseCollectionGroup): """Represents a Collection Group in the Firestore API. @@ -279,6 +290,7 @@ def __init__( start_at=None, end_at=None, all_descendants=True, + recursive=False, ) -> None: super(CollectionGroup, self).__init__( parent=parent, @@ -291,6 +303,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) @staticmethod diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 6d4471461c..6e72e65cf3 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -1212,6 +1212,127 @@ def test_array_union(client, cleanup): assert doc_ref.get().to_dict() == expected +def test_recursive_query(client, cleanup): + + philosophers = [ + { + "data": {"name": "Socrates", "favoriteCity": "Athens"}, + "subcollections": { + "pets": [{"name": "Scruffy"}, {"name": "Snowflake"}], + "hobbies": [{"name": "pontificating"}, {"name": "journaling"}], + "philosophers": [{"name": "Aristotle"}, {"name": "Plato"}], + }, + }, + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], + "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], + }, + }, + { + "data": {"name": "Plato", "favoriteCity": "Corinth"}, + "subcollections": { + "pets": [{"name": "Cuddles"}, {"name": "Sergeant-Puppers"}], + "hobbies": [{"name": "abstraction"}, {"name": "hypotheticals"}], + }, + }, + ] + + db = client + collection_ref = db.collection("philosophers") + for philosopher in philosophers: + ref = collection_ref.document( + f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}" + ) + ref.set(philosopher["data"]) + cleanup(ref.delete) + for col_name, entries in philosopher["subcollections"].items(): + sub_col = ref.collection(col_name) + for entry in entries: + inner_doc_ref = sub_col.document(entry["name"]) + inner_doc_ref.set(entry) + cleanup(inner_doc_ref.delete) + + ids = [doc.id for doc in db.collection_group("philosophers").recursive().get()] + + expected_ids = [ + # Aristotle doc and subdocs + 
f"Aristotle{UNIQUE_RESOURCE_ID}", + "meditation", + "questioning-stuff", + "Doggy-Dog", + "Floof-Boy", + # Plato doc and subdocs + f"Plato{UNIQUE_RESOURCE_ID}", + "abstraction", + "hypotheticals", + "Cuddles", + "Sergeant-Puppers", + # Socrates doc and subdocs + f"Socrates{UNIQUE_RESOURCE_ID}", + "journaling", + "pontificating", + "Scruffy", + "Snowflake", + "Aristotle", + "Plato", + ] + + assert len(ids) == len(expected_ids) + + for index in range(len(ids)): + error_msg = ( + f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" + ) + assert ids[index] == expected_ids[index], error_msg + + +def test_nested_recursive_query(client, cleanup): + + philosophers = [ + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], + "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], + }, + }, + ] + + db = client + collection_ref = db.collection("philosophers") + for philosopher in philosophers: + ref = collection_ref.document( + f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}" + ) + ref.set(philosopher["data"]) + cleanup(ref.delete) + for col_name, entries in philosopher["subcollections"].items(): + sub_col = ref.collection(col_name) + for entry in entries: + inner_doc_ref = sub_col.document(entry["name"]) + inner_doc_ref.set(entry) + cleanup(inner_doc_ref.delete) + + aristotle = collection_ref.document(f"Aristotle{UNIQUE_RESOURCE_ID}") + ids = [doc.id for doc in aristotle.collection("pets")._query().recursive().get()] + + expected_ids = [ + # Aristotle pets + "Doggy-Dog", + "Floof-Boy", + ] + + assert len(ids) == len(expected_ids) + + for index in range(len(ids)): + error_msg = ( + f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" + ) + assert ids[index] == expected_ids[index], error_msg + + def test_watch_query_order(client, cleanup): db = client collection_ref = db.collection("users") diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 65a46d9841..ef8022f0e7 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -1071,6 +1071,131 @@ async def test_batch(client, cleanup): assert not (await document3.get()).exists +async def test_recursive_query(client, cleanup): + + philosophers = [ + { + "data": {"name": "Socrates", "favoriteCity": "Athens"}, + "subcollections": { + "pets": [{"name": "Scruffy"}, {"name": "Snowflake"}], + "hobbies": [{"name": "pontificating"}, {"name": "journaling"}], + "philosophers": [{"name": "Aristotle"}, {"name": "Plato"}], + }, + }, + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], + "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], + }, + }, + { + "data": {"name": "Plato", "favoriteCity": "Corinth"}, + "subcollections": { + "pets": [{"name": "Cuddles"}, {"name": "Sergeant-Puppers"}], + "hobbies": [{"name": "abstraction"}, {"name": "hypotheticals"}], + }, + }, + ] + + db = client + collection_ref = db.collection("philosophers") + for philosopher in philosophers: + ref = collection_ref.document( + f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}-async" + ) + await ref.set(philosopher["data"]) + cleanup(ref.delete) + for col_name, entries in philosopher["subcollections"].items(): + sub_col = ref.collection(col_name) + for entry in entries: + inner_doc_ref = sub_col.document(entry["name"]) + await inner_doc_ref.set(entry) + 
cleanup(inner_doc_ref.delete) + + ids = [ + doc.id for doc in await db.collection_group("philosophers").recursive().get() + ] + + expected_ids = [ + # Aristotle doc and subdocs + f"Aristotle{UNIQUE_RESOURCE_ID}-async", + "meditation", + "questioning-stuff", + "Doggy-Dog", + "Floof-Boy", + # Plato doc and subdocs + f"Plato{UNIQUE_RESOURCE_ID}-async", + "abstraction", + "hypotheticals", + "Cuddles", + "Sergeant-Puppers", + # Socrates doc and subdocs + f"Socrates{UNIQUE_RESOURCE_ID}-async", + "journaling", + "pontificating", + "Scruffy", + "Snowflake", + "Aristotle", + "Plato", + ] + + assert len(ids) == len(expected_ids) + + for index in range(len(ids)): + error_msg = ( + f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" + ) + assert ids[index] == expected_ids[index], error_msg + + +async def test_nested_recursive_query(client, cleanup): + + philosophers = [ + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], + "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], + }, + }, + ] + + db = client + collection_ref = db.collection("philosophers") + for philosopher in philosophers: + ref = collection_ref.document( + f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}-async" + ) + await ref.set(philosopher["data"]) + cleanup(ref.delete) + for col_name, entries in philosopher["subcollections"].items(): + sub_col = ref.collection(col_name) + for entry in entries: + inner_doc_ref = sub_col.document(entry["name"]) + await inner_doc_ref.set(entry) + cleanup(inner_doc_ref.delete) + + aristotle = collection_ref.document(f"Aristotle{UNIQUE_RESOURCE_ID}-async") + ids = [ + doc.id for doc in await aristotle.collection("pets")._query().recursive().get() + ] + + expected_ids = [ + # Aristotle pets + "Doggy-Dog", + "Floof-Boy", + ] + + assert len(ids) == len(expected_ids) + + for index in range(len(ids)): + error_msg = ( + f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" + ) + assert ids[index] == expected_ids[index], error_msg + + async def _chain(*iterators): """Asynchronous reimplementation of `itertools.chain`.""" for iterator in iterators: diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index bf0959e043..33006e2542 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -375,6 +375,12 @@ async def test_stream_with_transaction(self, query_class): query_instance = query_class.return_value query_instance.stream.assert_called_once_with(transaction=transaction) + def test_recursive(self): + from google.cloud.firestore_v1.async_query import AsyncQuery + + col = self._make_one("collection") + self.assertIsInstance(col.recursive(), AsyncQuery) + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index a61aaedb26..3fb9a687f8 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1151,6 +1151,12 @@ def test_comparator_missing_order_by_field_in_data_raises(self): with self.assertRaisesRegex(ValueError, "Can only compare fields "): query._comparator(doc1, doc2) + def test_multiple_recursive_calls(self): + query = self._make_one(_make_client().collection("asdf")) + self.assertIsInstance( + query.recursive().recursive(), type(query), + ) + class Test__enum_from_op_string(unittest.TestCase): @staticmethod diff --git a/tests/unit/v1/test_collection.py 
b/tests/unit/v1/test_collection.py index feaec81194..5885a29d97 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -349,3 +349,9 @@ def test_on_snapshot(self, watch): collection = self._make_one("collection") collection.on_snapshot(None) watch.for_query.assert_called_once() + + def test_recursive(self): + from google.cloud.firestore_v1.query import Query + + col = self._make_one("collection") + self.assertIsInstance(col.recursive(), Query) From 98a7753f05240a2a75b9ffd42b7a148c65a6e87f Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Wed, 11 Aug 2021 12:41:35 -0400 Subject: [PATCH 10/19] feat: add bulk writer (#396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: bulk writer 555 rate_limiter (#368) * added 555 throttle utility * Update google/cloud/firestore_v1/throttle.py Co-authored-by: Tres Seaver * added ability to request a number of tokens * replaced Callable now parameter with module function * updated tests * renamed throttle -> ramp up * improved docstrings * linting * fixed test coverage * rename to RateLimiter and defer clock to first op * linting Co-authored-by: Tres Seaver * feat: added new batch class for BulkWriter (#397) * feat: added new batch class for BulkWriter * updated docstring to use less colloquial language * feat: BulkWriter implementation (#384) * feat: added `write` method to batch classes * added docstrings to all 3 batch classes instead of just the base * updated batch classes to remove control flag now branches logic via subclasses * fixed broken tests off abstract class * fixed docstring * refactored BulkWriteBatch this commit increases the distance between WriteBatch and BulkWriteBatch * began adding [Async]BulkWriter * continued implementation * working impl or BW * tidied up BW impl * beginning of unit tests for BW * fixed merge problem * initial set of BW unit tests * refactored bulkwriter sending mechanism now consumes off the queue and schedules on the main thread, only going async to actually send * final CI touch ups * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md * moved BulkWriter parameters to options format * rebased off master * test fixes Co-authored-by: Owl Bot * feat: add retry support for BulkWriter errors (#413) * parent 0176cc7fef8752433b5c2496046d3a56557eb824 author Craig Labenz 1623693904 -0700 committer Craig Labenz 1628617523 -0400 feat: add retries to bulk-writer * fixed rebase error Co-authored-by: Tres Seaver Co-authored-by: Owl Bot --- google/cloud/firestore_v1/async_client.py | 13 + google/cloud/firestore_v1/base_batch.py | 44 +- google/cloud/firestore_v1/base_client.py | 20 +- google/cloud/firestore_v1/batch.py | 4 +- google/cloud/firestore_v1/bulk_batch.py | 89 ++ google/cloud/firestore_v1/bulk_writer.py | 978 ++++++++++++++++++++++ google/cloud/firestore_v1/rate_limiter.py | 177 ++++ tests/system/test_system.py | 23 + tests/system/test_system_async.py | 23 + tests/unit/v1/_test_helpers.py | 21 + tests/unit/v1/test_async_batch.py | 2 + tests/unit/v1/test_async_client.py | 15 + tests/unit/v1/test_base_batch.py | 16 +- tests/unit/v1/test_batch.py | 3 + tests/unit/v1/test_bulk_batch.py | 105 +++ tests/unit/v1/test_bulk_writer.py | 600 +++++++++++++ tests/unit/v1/test_client.py | 8 + tests/unit/v1/test_rate_limiter.py | 200 +++++ 18 files changed, 2325 
insertions(+), 16 deletions(-) create mode 100644 google/cloud/firestore_v1/bulk_batch.py create mode 100644 google/cloud/firestore_v1/bulk_writer.py create mode 100644 google/cloud/firestore_v1/rate_limiter.py create mode 100644 tests/unit/v1/test_bulk_batch.py create mode 100644 tests/unit/v1/test_bulk_writer.py create mode 100644 tests/unit/v1/test_rate_limiter.py diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 8623f640d1..68cb676f2a 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -96,6 +96,19 @@ def __init__( client_options=client_options, ) + def _to_sync_copy(self): + from google.cloud.firestore_v1.client import Client + + if not getattr(self, "_sync_copy", None): + self._sync_copy = Client( + project=self.project, + credentials=self._credentials, + database=self._database, + client_info=self._client_info, + client_options=self._client_options, + ) + return self._sync_copy + @property def _firestore_api(self): """Lazy-loading getter GAPIC Firestore API. diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index 348a6ac454..a4b7ff0bb7 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -14,16 +14,16 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" - -from google.cloud.firestore_v1 import _helpers +import abc +from typing import Dict, Union # Types needed only for Type Hints -from google.cloud.firestore_v1.document import DocumentReference - -from typing import Union +from google.api_core import retry as retries # type: ignore +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_document import BaseDocumentReference -class BaseWriteBatch(object): +class BaseBatch(metaclass=abc.ABCMeta): """Accumulate write operations to be sent in a batch. This has the same set of methods for write operations that @@ -38,9 +38,16 @@ class BaseWriteBatch(object): def __init__(self, client) -> None: self._client = client self._write_pbs = [] + self._document_references: Dict[str, BaseDocumentReference] = {} self.write_results = None self.commit_time = None + def __len__(self): + return len(self._document_references) + + def __contains__(self, reference: BaseDocumentReference): + return reference._document_path in self._document_references + def _add_write_pbs(self, write_pbs: list) -> None: """Add `Write`` protobufs to this transaction. @@ -52,7 +59,13 @@ def _add_write_pbs(self, write_pbs: list) -> None: """ self._write_pbs.extend(write_pbs) - def create(self, reference: DocumentReference, document_data: dict) -> None: + @abc.abstractmethod + def commit(self): + """Sends all accumulated write operations to the server. The details of this + write depend on the implementing class.""" + raise NotImplementedError() + + def create(self, reference: BaseDocumentReference, document_data: dict) -> None: """Add a "change" to this batch to create a document. If the document given by ``reference`` already exists, then this @@ -65,11 +78,12 @@ def create(self, reference: DocumentReference, document_data: dict) -> None: creating a document. 
""" write_pbs = _helpers.pbs_for_create(reference._document_path, document_data) + self._document_references[reference._document_path] = reference self._add_write_pbs(write_pbs) def set( self, - reference: DocumentReference, + reference: BaseDocumentReference, document_data: dict, merge: Union[bool, list] = False, ) -> None: @@ -98,11 +112,12 @@ def set( reference._document_path, document_data ) + self._document_references[reference._document_path] = reference self._add_write_pbs(write_pbs) def update( self, - reference: DocumentReference, + reference: BaseDocumentReference, field_updates: dict, option: _helpers.WriteOption = None, ) -> None: @@ -126,10 +141,11 @@ def update( write_pbs = _helpers.pbs_for_update( reference._document_path, field_updates, option ) + self._document_references[reference._document_path] = reference self._add_write_pbs(write_pbs) def delete( - self, reference: DocumentReference, option: _helpers.WriteOption = None + self, reference: BaseDocumentReference, option: _helpers.WriteOption = None ) -> None: """Add a "change" to delete a document. @@ -146,9 +162,15 @@ def delete( state of the document before applying changes. """ write_pb = _helpers.pb_for_delete(reference._document_path, option) + self._document_references[reference._document_path] = reference self._add_write_pbs([write_pb]) - def _prep_commit(self, retry, timeout): + +class BaseWriteBatch(BaseBatch): + """Base class for a/sync implementations of the `commit` RPC. `commit` is useful + for lower volumes or when the order of write operations is important.""" + + def _prep_commit(self, retry: retries.Retry, timeout: float): """Shared setup for async/sync :meth:`commit`.""" request = { "database": self._client._database_string, diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 7eb5c26b08..e68031ed4d 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -37,7 +37,10 @@ from google.cloud.firestore_v1 import __version__ from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.base_document import DocumentSnapshot - +from google.cloud.firestore_v1.bulk_writer import ( + BulkWriter, + BulkWriterOptions, +) from google.cloud.firestore_v1.field_path import render_field_path from typing import ( Any, @@ -278,6 +281,21 @@ def _get_collection_reference(self, collection_id: str) -> BaseCollectionReferen def document(self, *document_path) -> BaseDocumentReference: raise NotImplementedError + def bulk_writer(self, options: Optional[BulkWriterOptions] = None) -> BulkWriter: + """Get a BulkWriter instance from this client. + + Args: + :class:`@google.cloud.firestore_v1.bulk_writer.BulkWriterOptions`: + Optional control parameters for the + :class:`@google.cloud.firestore_v1.bulk_writer.BulkWriter` returned. + + Returns: + :class:`@google.cloud.firestore_v1.bulk_writer.BulkWriter`: + A utility to efficiently create and save many `WriteBatch` instances + to the server. + """ + return BulkWriter(client=self, options=options) + def _document_path_helper(self, *document_path) -> List[str]: """Standardize the format of path to tuple of path segments and strip the database string from path if present. diff --git a/google/cloud/firestore_v1/batch.py b/google/cloud/firestore_v1/batch.py index 1758051228..a7ad074ba5 100644 --- a/google/cloud/firestore_v1/batch.py +++ b/google/cloud/firestore_v1/batch.py @@ -21,7 +21,9 @@ class WriteBatch(BaseWriteBatch): - """Accumulate write operations to be sent in a batch. 
+ """Accumulate write operations to be sent in a batch. Use this over + `BulkWriteBatch` for lower volumes or when the order of operations + within a given batch is important. This has the same set of methods for write operations that :class:`~google.cloud.firestore_v1.document.DocumentReference` does, diff --git a/google/cloud/firestore_v1/bulk_batch.py b/google/cloud/firestore_v1/bulk_batch.py new file mode 100644 index 0000000000..bc2f75a38b --- /dev/null +++ b/google/cloud/firestore_v1/bulk_batch.py @@ -0,0 +1,89 @@ +# Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for batch requests to the Google Cloud Firestore API.""" +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_batch import BaseBatch +from google.cloud.firestore_v1.types.firestore import BatchWriteResponse + + +class BulkWriteBatch(BaseBatch): + """Accumulate write operations to be sent in a batch. Use this over + `WriteBatch` for higher volumes (e.g., via `BulkWriter`) and when the order + of operations within a given batch is unimportant. + + Because the order in which individual write operations are applied to the database + is not guaranteed, `batch_write` RPCs can never contain multiple operations + to the same document. If calling code detects a second write operation to a + known document reference, it should first cut off the previous batch and + send it, then create a new batch starting with the latest write operation. + In practice, the [Async]BulkWriter classes handle this. + + This has the same set of methods for write operations that + :class:`~google.cloud.firestore_v1.document.DocumentReference` does, + e.g. :meth:`~google.cloud.firestore_v1.document.DocumentReference.create`. + + Args: + client (:class:`~google.cloud.firestore_v1.client.Client`): + The client that created this batch. + """ + + def __init__(self, client) -> None: + super(BulkWriteBatch, self).__init__(client=client) + + def commit( + self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None + ) -> BatchWriteResponse: + """Writes the changes accumulated in this batch. + + Write operations are not guaranteed to be applied in order and must not + contain multiple writes to any given document. Preferred over `commit` + for performance reasons if these conditions are acceptable. + + Args: + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Returns: + :class:`google.cloud.proto.firestore.v1.write.BatchWriteResponse`: + Container holding the write results corresponding to the changes + committed, returned in the same order as the changes were applied to + this batch. An individual write result contains an ``update_time`` + field. 
+ """ + request, kwargs = self._prep_commit(retry, timeout) + + _api = self._client._firestore_api + save_response: BatchWriteResponse = _api.batch_write( + request=request, metadata=self._client._rpc_metadata, **kwargs, + ) + + self._write_pbs = [] + self.write_results = list(save_response.write_results) + + return save_response + + def _prep_commit(self, retry: retries.Retry, timeout: float): + request = { + "database": self._client._database_string, + "writes": self._write_pbs, + "labels": None, + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + return request, kwargs diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py new file mode 100644 index 0000000000..ad886f81d3 --- /dev/null +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -0,0 +1,978 @@ +# Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for efficiently writing large amounts of data to the Google Cloud +Firestore API.""" + +import bisect +import collections +import concurrent.futures +import datetime +import enum +import functools +import logging +import time + +from typing import Callable, Dict, List, Optional, Union, TYPE_CHECKING + +from google.rpc import status_pb2 # type: ignore + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch +from google.cloud.firestore_v1.rate_limiter import RateLimiter +from google.cloud.firestore_v1.types.firestore import BatchWriteResponse +from google.cloud.firestore_v1.types.write import WriteResult + +if TYPE_CHECKING: + from google.cloud.firestore_v1.base_client import BaseClient # pragma: NO COVER + + +logger = logging.getLogger(__name__) + + +class BulkRetry(enum.Enum): + """Indicator for what retry strategy the BulkWriter should use.""" + + # Common exponential backoff algorithm. This strategy is largely incompatible + # with the default retry limit of 15, so use with caution. + exponential = enum.auto() + + # Default strategy that adds 1 second of delay per retry. + linear = enum.auto() + + # Immediate retries with no growing delays. + immediate = enum.auto() + + +class SendMode(enum.Enum): + """Indicator for whether a BulkWriter should commit batches in the main + thread or hand that work off to an executor.""" + + # Default strategy that parallelizes network I/O on an executor. You almost + # certainly want this. + parallel = enum.auto() + + # Alternate strategy which blocks during all network I/O. Much slower, but + # assures all batches are sent to the server in order. Note that + # `SendMode.serial` is extremely susceptible to slowdowns from retries if + # there are a lot of errors. + serial = enum.auto() + + +class AsyncBulkWriterMixin: + """ + Mixin which contains the methods on `BulkWriter` which must only be + submitted to the executor (or called by functions submitted to the executor). 
+ This mixin exists purely for organization and clarity of implementation + (e.g., there is no metaclass magic). + + The entrypoint to the parallelizable code path is `_send_batch()`, which is + wrapped in a decorator which ensures that the `SendMode` is honored. + """ + + def _with_send_mode(fn): + """Decorates a method to ensure it is only called via the executor + (IFF the SendMode value is SendMode.parallel!). + + Usage: + + @_with_send_mode + def my_method(self): + parallel_stuff() + + def something_else(self): + # Because of the decorator around `my_method`, the following + # method invocation: + self.my_method() + # becomes equivalent to `self._executor.submit(self.my_method)` + # when the send mode is `SendMode.parallel`. + + Use on entrypoint methods for code paths that *must* be parallelized. + """ + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + if self._send_mode == SendMode.parallel: + return self._executor.submit(lambda: fn(self, *args, **kwargs)) + else: + # For code parity, even `SendMode.serial` scenarios should return + # a future here. Anything else would badly complicate calling code. + result = fn(self, *args, **kwargs) + future = concurrent.futures.Future() + future.set_result(result) + return future + + return wrapper + + @_with_send_mode + def _send_batch( + self, batch: BulkWriteBatch, operations: List["BulkWriterOperation"] + ): + """Sends a batch without regard to rate limits, meaning limits must have + already been checked. To that end, do not call this directly; instead, + call `_send_until_queue_is_empty`. + + Args: + batch(:class:`~google.cloud.firestore_v1.base_batch.BulkWriteBatch`) + """ + _len_batch: int = len(batch) + self._in_flight_documents += _len_batch + response: BatchWriteResponse = self._send(batch) + self._in_flight_documents -= _len_batch + + # Update bookkeeping totals + self._total_batches_sent += 1 + self._total_write_operations += _len_batch + + self._process_response(batch, response, operations) + + def _process_response( + self, + batch: BulkWriteBatch, + response: BatchWriteResponse, + operations: List["BulkWriterOperation"], + ) -> None: + """Invokes submitted callbacks for each batch and each operation within + each batch. As this is called from `_send_batch()`, this is parallelized + if we are in that mode. 
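+        Each write whose returned status code is zero is routed to the
+        registered success callback; any other status is wrapped in a
+        ``BulkWriteFailure`` and handed to the error callback, which decides
+        whether the operation should be retried.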
+ """ + batch_references: List[BaseDocumentReference] = list( + batch._document_references.values(), + ) + self._batch_callback(batch, response, self) + + status: status_pb2.Status + for index, status in enumerate(response.status): + if status.code == 0: + self._success_callback( + # DocumentReference + batch_references[index], + # WriteResult + response.write_results[index], + # BulkWriter + self, + ) + else: + operation: BulkWriterOperation = operations[index] + should_retry: bool = self._error_callback( + # BulkWriteFailure + BulkWriteFailure( + operation=operation, code=status.code, message=status.message, + ), + # BulkWriter + self, + ) + if should_retry: + operation.attempts += 1 + self._retry_operation(operation) + + def _retry_operation( + self, operation: "BulkWriterOperation", + ) -> concurrent.futures.Future: + + delay: int = 0 + if self._options.retry == BulkRetry.exponential: + delay = operation.attempts ** 2 # pragma: NO COVER + elif self._options.retry == BulkRetry.linear: + delay = operation.attempts + + run_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=delay) + + # Use of `bisect.insort` maintains the requirement that `self._retries` + # always remain sorted by each object's `run_at` time. Note that it is + # able to do this because `OperationRetry` instances are entirely sortable + # by their `run_at` value. + bisect.insort( + self._retries, OperationRetry(operation=operation, run_at=run_at), + ) + + def _send(self, batch: BulkWriteBatch) -> BatchWriteResponse: + """Hook for overwriting the sending of batches. As this is only called + from `_send_batch()`, this is parallelized if we are in that mode. + """ + return batch.commit() # pragma: NO COVER + + +class BulkWriter(AsyncBulkWriterMixin): + """ + Accumulate and efficiently save large amounts of document write operations + to the server. + + BulkWriter can handle large data migrations or updates, buffering records + in memory and submitting them to the server in batches of 20. + + The submission of batches is internally parallelized with a ThreadPoolExecutor, + meaning end developers do not need to manage an event loop or worry about asyncio + to see parallelization speed ups (which can easily 10x throughput). Because + of this, there is no companion `AsyncBulkWriter` class, as is usually seen + with other utility classes. + + Usage: + + # Instantiate the BulkWriter. This works from either `Client` or + # `AsyncClient`. + db = firestore.Client() + bulk_writer = db.bulk_writer() + + # Attach an optional success listener to be called once per document. + bulk_writer.on_write_result( + lambda reference, result, bulk_writer: print(f'Saved {reference._document_path}') + ) + + # Queue an arbitrary amount of write operations. + # Assume `my_new_records` is a list of (DocumentReference, dict,) + # tuple-pairs that you supply. + + reference: DocumentReference + data: dict + for reference, data in my_new_records: + bulk_writer.create(reference, data) + + # Block until all pooled writes are complete. + bulk_writer.flush() + + Args: + client(:class:`~google.cloud.firestore_v1.client.Client`): + The client that created this BulkWriter. + """ + + batch_size: int = 20 + + def __init__( + self, + client: Optional["BaseClient"] = None, + options: Optional["BulkWriterOptions"] = None, + ): + # Because `BulkWriter` instances are all synchronous/blocking on the + # main thread (instead using other threads for asynchrony), it is + # incompatible with AsyncClient's various methods that return Futures. 
+ # `BulkWriter` parallelizes all of its network I/O without the developer + # having to worry about awaiting async methods, so we must convert an + # AsyncClient instance into a plain Client instance. + self._client = ( + client._to_sync_copy() if type(client).__name__ == "AsyncClient" else client + ) + self._options = options or BulkWriterOptions() + self._send_mode = self._options.mode + + self._operations: List[BulkWriterOperation] + # List of the `_document_path` attribute for each DocumentReference + # contained in the current `self._operations`. This is reset every time + # `self._operations` is reset. + self._operations_document_paths: List[BaseDocumentReference] + self._reset_operations() + + # List of all `BulkWriterOperation` objects that are waiting to be retried. + # Each such object is wrapped in an `OperationRetry` object which pairs + # the raw operation with the `datetime` of its next scheduled attempt. + # `self._retries` must always remain sorted for efficient reads, so it is + # required to only ever add elements via `bisect.insort`. + self._retries: collections.deque["OperationRetry"] = collections.deque([]) + + self._queued_batches = collections.deque([]) + self._is_open: bool = True + + # This list will go on to store the future returned from each submission + # to the executor, for the purpose of awaiting all of those futures' + # completions in the `flush` method. + self._pending_batch_futures: List[concurrent.futures.Future] = [] + + self._success_callback: Callable[ + [BaseDocumentReference, WriteResult, "BulkWriter"], None + ] = BulkWriter._default_on_success + self._batch_callback: Callable[ + [BulkWriteBatch, BatchWriteResponse, "BulkWriter"], None + ] = BulkWriter._default_on_batch + self._error_callback: Callable[ + [BulkWriteFailure, BulkWriter], bool + ] = BulkWriter._default_on_error + + self._in_flight_documents: int = 0 + self._rate_limiter = RateLimiter( + initial_tokens=self._options.initial_ops_per_second, + global_max_tokens=self._options.max_ops_per_second, + ) + + # Keep track of progress as batches and write operations are completed + self._total_batches_sent: int = 0 + self._total_write_operations: int = 0 + + self._ensure_executor() + + @staticmethod + def _default_on_batch( + batch: BulkWriteBatch, response: BatchWriteResponse, bulk_writer: "BulkWriter", + ) -> None: + pass + + @staticmethod + def _default_on_success( + reference: BaseDocumentReference, + result: WriteResult, + bulk_writer: "BulkWriter", + ) -> None: + pass + + @staticmethod + def _default_on_error(error: "BulkWriteFailure", bulk_writer: "BulkWriter") -> bool: + # Default number of retries for each operation is 15. This is a scary + # number to combine with an exponential backoff, and as such, our default + # backoff strategy is linear instead of exponential. + return error.attempts < 15 + + def _reset_operations(self) -> None: + self._operations = [] + self._operations_document_paths = [] + + def _ensure_executor(self): + """Reboots the executor used to send batches if it has been shutdown.""" + if getattr(self, "_executor", None) is None or self._executor._shutdown: + self._executor = self._instantiate_executor() + + def _ensure_sending(self): + self._ensure_executor() + self._send_until_queue_is_empty() + + def _instantiate_executor(self): + return concurrent.futures.ThreadPoolExecutor() + + def flush(self): + """ + Block until all pooled write operations are complete and then resume + accepting new write operations. + """ + # Calling `flush` consecutively is a no-op. 
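+        # (The executor is shut down at the end of every `flush`, so a repeat
+        # call that has nothing newly queued finds `_shutdown` already set and
+        # returns immediately.)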
+ if self._executor._shutdown: + return + + while True: + + # Queue any waiting operations and try our luck again. + # This can happen if users add a number of records not divisible by + # 20 and then call flush (which should be ~19 out of 20 use cases). + # Execution will arrive here and find the leftover operations that + # never filled up a batch organically, and so we must send them here. + if self._operations: + self._enqueue_current_batch() + continue + + # If we find queued but unsent batches or pending retries, begin + # sending immediately. Note that if we are waiting on retries, but + # they have longer to wait as specified by the retry backoff strategy, + # we may have to make several passes through this part of the loop. + # (This is related to the sleep and its explanation below.) + if self._queued_batches or self._retries: + self._ensure_sending() + + # This sleep prevents max-speed laps through this loop, which can + # and will happen if the BulkWriter is doing nothing except waiting + # on retries to be ready to re-send. Removing this sleep will cause + # whatever thread is running this code to sit near 100% CPU until + # all retries are abandoned or successfully resolved. + time.sleep(0.1) + continue + + # We store the executor's Future from each batch send operation, so + # the first pass through here, we are guaranteed to find "pending" + # batch futures and have to wait. However, the second pass through + # will be fast unless the last batch introduced more retries. + if self._pending_batch_futures: + _batches = self._pending_batch_futures + self._pending_batch_futures = [] + concurrent.futures.wait(_batches) + + # Continuing is critical here (as opposed to breaking) because + # the final batch may have introduced retries which is most + # straightforwardly verified by heading back to the top of the loop. + continue + + break + + # We no longer expect to have any queued batches or pending futures, + # so the executor can be shutdown. + self._executor.shutdown() + + def close(self): + """ + Block until all pooled write operations are complete and then reject + any further write operations. + """ + self._is_open = False + self.flush() + + def _maybe_enqueue_current_batch(self): + """ + Checks to see whether the in-progress batch is full and, if it is, + adds it to the sending queue. + """ + if len(self._operations) >= self.batch_size: + self._enqueue_current_batch() + + def _enqueue_current_batch(self): + """Adds the current batch to the back of the sending line, resets the + list of queued ops, and begins the process of actually sending whatever + batch is in the front of the line, which will often be a different batch. + """ + # Put our batch in the back of the sending line + self._queued_batches.append(self._operations) + + # Reset the local store of operations + self._reset_operations() + + # The sending loop powers off upon reaching the end of the queue, so + # here we make sure that is running. + self._ensure_sending() + + def _send_until_queue_is_empty(self): + """First domino in the sending codepath. This does not need to be + parallelized for two reasons: + + 1) Putting this on a worker thread could lead to two running in parallel + and thus unpredictable commit ordering or failure to adhere to + rate limits. 
+ 2) This method only blocks when `self._request_send()` does not immediately + return, and in that case, the BulkWriter's ramp-up / throttling logic + has determined that it is attempting to exceed the maximum write speed, + and so parallelizing this method would not increase performance anyway. + + Once `self._request_send()` returns, this method calls `self._send_batch()`, + which parallelizes itself if that is our SendMode value. + + And once `self._send_batch()` is called (which does not block if we are + sending in parallel), jumps back to the top and re-checks for any queued + batches. + + Note that for sufficiently large data migrations, this can block the + submission of additional write operations (e.g., the CRUD methods); + but again, that is only if the maximum write speed is being exceeded, + and thus this scenario does not actually further reduce performance. + """ + self._schedule_ready_retries() + + while self._queued_batches: + + # For FIFO order, add to the right of this deque (via `append`) and take + # from the left (via `popleft`). + operations: List[BulkWriterOperation] = self._queued_batches.popleft() + + # Block until we are cleared for takeoff, which is fine because this + # returns instantly unless the rate limiting logic determines that we + # are attempting to exceed the maximum write speed. + self._request_send(len(operations)) + + # Handle some bookkeeping, and ultimately put these bits on the wire. + batch = BulkWriteBatch(client=self._client) + op: BulkWriterOperation + for op in operations: + op.add_to_batch(batch) + + # `_send_batch` is optionally parallelized by `@_with_send_mode`. + future = self._send_batch(batch=batch, operations=operations) + self._pending_batch_futures.append(future) + + self._schedule_ready_retries() + + def _schedule_ready_retries(self): + """Grabs all ready retries and re-queues them.""" + + # Because `self._retries` always exists in a sorted state (thanks to only + # ever adding to it via `bisect.insort`), and because `OperationRetry` + # objects are comparable against `datetime` objects, this bisect functionally + # returns the number of retires that are ready for immediate reenlistment. + take_until_index = bisect.bisect(self._retries, datetime.datetime.utcnow()) + + for _ in range(take_until_index): + retry: OperationRetry = self._retries.popleft() + retry.retry(self) + + def _request_send(self, batch_size: int) -> bool: + # Set up this boolean to avoid repeatedly taking tokens if we're only + # waiting on the `max_in_flight` limit. + have_received_tokens: bool = False + + while True: + # To avoid bottlenecks on the server, an additional limit is that no + # more write operations can be "in flight" (sent but still awaiting + # response) at any given point than the maximum number of writes per + # second. + under_threshold: bool = ( + self._in_flight_documents <= self._rate_limiter._maximum_tokens + ) + # Ask for tokens each pass through this loop until they are granted, + # and then stop. + have_received_tokens = ( + have_received_tokens or self._rate_limiter.take_tokens(batch_size) + ) + if not under_threshold or not have_received_tokens: + # Try again until both checks are true. + # Note that this sleep is helpful to prevent the main BulkWriter + # thread from spinning through this loop as fast as possible and + # pointlessly burning CPU while we wait for the arrival of a + # fixed moment in the future. 
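+                # A 0.01 second pause caps this polling loop at roughly 100
+                # iterations per second.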
+ time.sleep(0.01) + continue + + return True + + def create( + self, reference: BaseDocumentReference, document_data: Dict, attempts: int = 0, + ) -> None: + """Adds a `create` pb to the in-progress batch. + + If the in-progress batch already contains a write operation involving + this document reference, the batch will be sealed and added to the commit + queue, and a new batch will be created with this operation as its first + entry. + + If this create operation results in the in-progress batch reaching full + capacity, then the batch will be similarly added to the commit queue, and + a new batch will be created for future operations. + + Args: + reference (:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`): + Pointer to the document that should be created. + document_data (dict): + Raw data to save to the server. + """ + self._verify_not_closed() + + if reference._document_path in self._operations_document_paths: + self._enqueue_current_batch() + + self._operations.append( + BulkWriterCreateOperation( + reference=reference, document_data=document_data, attempts=attempts, + ), + ) + self._operations_document_paths.append(reference._document_path) + + self._maybe_enqueue_current_batch() + + def delete( + self, + reference: BaseDocumentReference, + option: Optional[_helpers.WriteOption] = None, + attempts: int = 0, + ) -> None: + """Adds a `delete` pb to the in-progress batch. + + If the in-progress batch already contains a write operation involving + this document reference, the batch will be sealed and added to the commit + queue, and a new batch will be created with this operation as its first + entry. + + If this delete operation results in the in-progress batch reaching full + capacity, then the batch will be similarly added to the commit queue, and + a new batch will be created for future operations. + + Args: + reference (:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`): + Pointer to the document that should be created. + option (:class:`~google.cloud.firestore_v1._helpers.WriteOption`): + Optional flag to modify the nature of this write. + """ + self._verify_not_closed() + + if reference._document_path in self._operations_document_paths: + self._enqueue_current_batch() + + self._operations.append( + BulkWriterDeleteOperation( + reference=reference, option=option, attempts=attempts, + ), + ) + self._operations_document_paths.append(reference._document_path) + + self._maybe_enqueue_current_batch() + + def set( + self, + reference: BaseDocumentReference, + document_data: Dict, + merge: Union[bool, list] = False, + attempts: int = 0, + ) -> None: + """Adds a `set` pb to the in-progress batch. + + If the in-progress batch already contains a write operation involving + this document reference, the batch will be sealed and added to the commit + queue, and a new batch will be created with this operation as its first + entry. + + If this set operation results in the in-progress batch reaching full + capacity, then the batch will be similarly added to the commit queue, and + a new batch will be created for future operations. + + Args: + reference (:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`): + Pointer to the document that should be created. + document_data (dict): + Raw data to save to the server. + merge (bool): + Whether or not to completely overwrite any existing data with + the supplied data. 
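+                For example (illustrative only), ``set(ref, {"name": "Kirk"},
+                merge=True)`` updates the ``name`` field while preserving the
+                document's other fields, whereas ``merge=False`` replaces the
+                entire document with the supplied data.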
+ """ + self._verify_not_closed() + + if reference._document_path in self._operations_document_paths: + self._enqueue_current_batch() + + self._operations.append( + BulkWriterSetOperation( + reference=reference, + document_data=document_data, + merge=merge, + attempts=attempts, + ) + ) + self._operations_document_paths.append(reference._document_path) + + self._maybe_enqueue_current_batch() + + def update( + self, + reference: BaseDocumentReference, + field_updates: dict, + option: Optional[_helpers.WriteOption] = None, + attempts: int = 0, + ) -> None: + """Adds an `update` pb to the in-progress batch. + + If the in-progress batch already contains a write operation involving + this document reference, the batch will be sealed and added to the commit + queue, and a new batch will be created with this operation as its first + entry. + + If this update operation results in the in-progress batch reaching full + capacity, then the batch will be similarly added to the commit queue, and + a new batch will be created for future operations. + + Args: + reference (:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`): + Pointer to the document that should be created. + field_updates (dict): + Key paths to specific nested data that should be upated. + option (:class:`~google.cloud.firestore_v1._helpers.WriteOption`): + Optional flag to modify the nature of this write. + """ + # This check is copied from other Firestore classes for the purposes of + # surfacing the error immediately. + if option.__class__.__name__ == "ExistsOption": + raise ValueError("you must not pass an explicit write option to update.") + + self._verify_not_closed() + + if reference._document_path in self._operations_document_paths: + self._enqueue_current_batch() + + self._operations.append( + BulkWriterUpdateOperation( + reference=reference, + field_updates=field_updates, + option=option, + attempts=attempts, + ) + ) + self._operations_document_paths.append(reference._document_path) + + self._maybe_enqueue_current_batch() + + def on_write_result( + self, + callback: Callable[[BaseDocumentReference, WriteResult, "BulkWriter"], None], + ) -> None: + """Sets a callback that will be invoked once for every successful operation.""" + self._success_callback = callback or BulkWriter._default_on_success + + def on_batch_result( + self, + callback: Callable[[BulkWriteBatch, BatchWriteResponse, "BulkWriter"], None], + ) -> None: + """Sets a callback that will be invoked once for every successful batch.""" + self._batch_callback = callback or BulkWriter._default_on_batch + + def on_write_error( + self, callback: Callable[["BulkWriteFailure", "BulkWriter"], bool] + ) -> None: + """Sets a callback that will be invoked once for every batch that contains + an error.""" + self._error_callback = callback or BulkWriter._default_on_error + + def _verify_not_closed(self): + if not self._is_open: + raise Exception("BulkWriter is closed and cannot accept new operations") + + +class BulkWriterOperation: + """Parent class for all operation container classes. + + `BulkWriterOperation` exists to house all the necessary information for a + specific write task, including meta information like the current number of + attempts. If a write fails, it is its wrapper `BulkWriteOperation` class + that ferries it into its next retry without getting confused with other + similar writes to the same document. 
+ """ + + def add_to_batch(self, batch: BulkWriteBatch): + """Adds `self` to the supplied batch.""" + assert isinstance(batch, BulkWriteBatch) + if isinstance(self, BulkWriterCreateOperation): + return batch.create( + reference=self.reference, document_data=self.document_data, + ) + + if isinstance(self, BulkWriterDeleteOperation): + return batch.delete(reference=self.reference, option=self.option,) + + if isinstance(self, BulkWriterSetOperation): + return batch.set( + reference=self.reference, + document_data=self.document_data, + merge=self.merge, + ) + + if isinstance(self, BulkWriterUpdateOperation): + return batch.update( + reference=self.reference, + field_updates=self.field_updates, + option=self.option, + ) + raise TypeError( + f"Unexpected type of {self.__class__.__name__} for batch" + ) # pragma: NO COVER + + +@functools.total_ordering +class BaseOperationRetry: + """Parent class for both the @dataclass and old-style `OperationRetry` + classes. + + Methods on this class be moved directly to `OperationRetry` when support for + Python 3.6 is dropped and `dataclasses` becomes universal. + """ + + def __lt__(self, other: "OperationRetry"): + """Allows use of `bisect` to maintain a sorted list of `OperationRetry` + instances, which in turn allows us to cheaply grab all that are ready to + run.""" + if isinstance(other, OperationRetry): + return self.run_at < other.run_at + elif isinstance(other, datetime.datetime): + return self.run_at < other + return NotImplemented # pragma: NO COVER + + def retry(self, bulk_writer: BulkWriter) -> None: + """Call this after waiting any necessary time to re-add the enclosed + operation to the supplied BulkWriter's internal queue.""" + if isinstance(self.operation, BulkWriterCreateOperation): + bulk_writer.create( + reference=self.operation.reference, + document_data=self.operation.document_data, + attempts=self.operation.attempts, + ) + + elif isinstance(self.operation, BulkWriterDeleteOperation): + bulk_writer.delete( + reference=self.operation.reference, + option=self.operation.option, + attempts=self.operation.attempts, + ) + + elif isinstance(self.operation, BulkWriterSetOperation): + bulk_writer.set( + reference=self.operation.reference, + document_data=self.operation.document_data, + merge=self.operation.merge, + attempts=self.operation.attempts, + ) + + elif isinstance(self.operation, BulkWriterUpdateOperation): + bulk_writer.update( + reference=self.operation.reference, + field_updates=self.operation.field_updates, + option=self.operation.option, + attempts=self.operation.attempts, + ) + else: + raise TypeError( + f"Unexpected type of {self.operation.__class__.__name__} for OperationRetry.retry" + ) # pragma: NO COVER + + +try: + from dataclasses import dataclass + + @dataclass + class BulkWriterOptions: + initial_ops_per_second: int = 500 + max_ops_per_second: int = 500 + mode: SendMode = SendMode.parallel + retry: BulkRetry = BulkRetry.linear + + @dataclass + class BulkWriteFailure: + operation: BulkWriterOperation + # https://grpc.github.io/grpc/core/md_doc_statuscodes.html + code: int + message: str + + @property + def attempts(self) -> int: + return self.operation.attempts + + @dataclass + class OperationRetry(BaseOperationRetry): + """Container for an additional attempt at an operation, scheduled for + the future.""" + + operation: BulkWriterOperation + run_at: datetime.datetime + + @dataclass + class BulkWriterCreateOperation(BulkWriterOperation): + """Container for BulkWriter.create() operations.""" + + reference: 
BaseDocumentReference + document_data: Dict + attempts: int = 0 + + @dataclass + class BulkWriterUpdateOperation(BulkWriterOperation): + """Container for BulkWriter.update() operations.""" + + reference: BaseDocumentReference + field_updates: Dict + option: Optional[_helpers.WriteOption] + attempts: int = 0 + + @dataclass + class BulkWriterSetOperation(BulkWriterOperation): + """Container for BulkWriter.set() operations.""" + + reference: BaseDocumentReference + document_data: Dict + merge: Union[bool, list] = False + attempts: int = 0 + + @dataclass + class BulkWriterDeleteOperation(BulkWriterOperation): + """Container for BulkWriter.delete() operations.""" + + reference: BaseDocumentReference + option: Optional[_helpers.WriteOption] + attempts: int = 0 + + +except ImportError: + + # Note: When support for Python 3.6 is dropped and `dataclasses` is reliably + # in the stdlib, this entire section can be dropped in favor of the dataclass + # versions above. Additonally, the methods on `BaseOperationRetry` can be added + # directly to `OperationRetry` and `BaseOperationRetry` can be deleted. + + class BulkWriterOptions: + def __init__( + self, + initial_ops_per_second: int = 500, + max_ops_per_second: int = 500, + mode: SendMode = SendMode.parallel, + retry: BulkRetry = BulkRetry.linear, + ): + self.initial_ops_per_second = initial_ops_per_second + self.max_ops_per_second = max_ops_per_second + self.mode = mode + self.retry = retry + + class BulkWriteFailure: + def __init__( + self, + operation: BulkWriterOperation, + # https://grpc.github.io/grpc/core/md_doc_statuscodes.html + code: int, + message: str, + ): + self.operation = operation + self.code = code + self.message = message + + @property + def attempts(self) -> int: + return self.operation.attempts + + class OperationRetry(BaseOperationRetry): + """Container for an additional attempt at an operation, scheduled for + the future.""" + + def __init__( + self, operation: BulkWriterOperation, run_at: datetime.datetime, + ): + self.operation = operation + self.run_at = run_at + + class BulkWriterCreateOperation(BulkWriterOperation): + """Container for BulkWriter.create() operations.""" + + def __init__( + self, + reference: BaseDocumentReference, + document_data: Dict, + attempts: int = 0, + ): + self.reference = reference + self.document_data = document_data + self.attempts = attempts + + class BulkWriterUpdateOperation(BulkWriterOperation): + """Container for BulkWriter.update() operations.""" + + def __init__( + self, + reference: BaseDocumentReference, + field_updates: Dict, + option: Optional[_helpers.WriteOption], + attempts: int = 0, + ): + self.reference = reference + self.field_updates = field_updates + self.option = option + self.attempts = attempts + + class BulkWriterSetOperation(BulkWriterOperation): + """Container for BulkWriter.set() operations.""" + + def __init__( + self, + reference: BaseDocumentReference, + document_data: Dict, + merge: Union[bool, list] = False, + attempts: int = 0, + ): + self.reference = reference + self.document_data = document_data + self.merge = merge + self.attempts = attempts + + class BulkWriterDeleteOperation(BulkWriterOperation): + """Container for BulkWriter.delete() operations.""" + + def __init__( + self, + reference: BaseDocumentReference, + option: Optional[_helpers.WriteOption], + attempts: int = 0, + ): + self.reference = reference + self.option = option + self.attempts = attempts diff --git a/google/cloud/firestore_v1/rate_limiter.py b/google/cloud/firestore_v1/rate_limiter.py new file 
mode 100644 index 0000000000..ee920edae0 --- /dev/null +++ b/google/cloud/firestore_v1/rate_limiter.py @@ -0,0 +1,177 @@ +# Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from typing import NoReturn, Optional + + +def utcnow(): + return datetime.datetime.utcnow() + + +default_initial_tokens: int = 500 +default_phase_length: int = 60 * 5 # 5 minutes +microseconds_per_second: int = 1000000 + + +class RateLimiter: + """Implements 5/5/5 ramp-up via Token Bucket algorithm. + + 5/5/5 is a ramp up strategy that starts with a budget of 500 operations per + second. Additionally, every 5 minutes, the maximum budget can increase by + 50%. Thus, at 5:01 into a long bulk-writing process, the maximum budget + becomes 750 operations per second. At 10:01, the budget becomes 1,125 + operations per second. + + The Token Bucket algorithm uses the metaphor of a bucket, or pile, or really + any container, if we're being honest, of tokens from which a user is able + to draw. If there are tokens available, you can do the thing. If there are not, + you can not do the thing. Additionally, tokens replenish at a fixed rate. + + Usage: + + rate_limiter = RateLimiter() + tokens = rate_limiter.take_tokens(20) + + if not tokens: + queue_retry() + else: + for _ in range(tokens): + my_operation() + + Args: + initial_tokens (Optional[int]): Starting size of the budget. Defaults + to 500. + phase_length (Optional[int]): Number of seconds, after which, the size + of the budget can increase by 50%. Such an increase will happen every + [phase_length] seconds if operation requests continue consistently. + """ + + def __init__( + self, + initial_tokens: int = default_initial_tokens, + global_max_tokens: Optional[int] = None, + phase_length: int = default_phase_length, + ): + # Tracks the volume of operations during a given ramp-up phase. + self._operations_this_phase: int = 0 + + # If provided, this enforces a cap on the maximum number of writes per + # second we can ever attempt, regardless of how many 50% increases the + # 5/5/5 rule would grant. + self._global_max_tokens = global_max_tokens + + self._start: Optional[datetime.datetime] = None + self._last_refill: Optional[datetime.datetime] = None + + # Current number of available operations. Decrements with every + # permitted request and refills over time. + self._available_tokens: int = initial_tokens + + # Maximum size of the available operations. Can increase by 50% + # every [phase_length] number of seconds. + self._maximum_tokens: int = self._available_tokens + + if self._global_max_tokens is not None: + self._available_tokens = min( + self._available_tokens, self._global_max_tokens + ) + self._maximum_tokens = min(self._maximum_tokens, self._global_max_tokens) + + # Number of seconds after which the [_maximum_tokens] can increase by 50%. + self._phase_length: int = phase_length + + # Tracks how many times the [_maximum_tokens] has increased by 50%. 
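+        # With the default of 500 initial tokens, phase 1 raises the ceiling
+        # to 750 operations per second and phase 2 to round(750 * 1.5) = 1,125,
+        # matching the 5/5/5 ramp-up described in the class docstring.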
+ self._phase: int = 0 + + def _start_clock(self): + self._start = self._start or utcnow() + self._last_refill = self._last_refill or utcnow() + + def take_tokens(self, num: Optional[int] = 1, allow_less: bool = False) -> int: + """Returns the number of available tokens, up to the amount requested.""" + self._start_clock() + self._check_phase() + self._refill() + + minimum_tokens = 1 if allow_less else num + + if self._available_tokens >= minimum_tokens: + _num_to_take = min(self._available_tokens, num) + self._available_tokens -= _num_to_take + self._operations_this_phase += _num_to_take + return _num_to_take + return 0 + + def _check_phase(self): + """Increments or decrements [_phase] depending on traffic. + + Every [_phase_length] seconds, if > 50% of available traffic was used + during the window, increases [_phase], otherwise, decreases [_phase]. + + This is a no-op unless a new [_phase_length] number of seconds since the + start was crossed since it was last called. + """ + age: datetime.timedelta = utcnow() - self._start + + # Uses integer division to calculate the expected phase. We start in + # Phase 0, so until [_phase_length] seconds have passed, this will + # not resolve to 1. + expected_phase: int = age.seconds // self._phase_length + + # Short-circuit if we are still in the expected phase. + if expected_phase == self._phase: + return + + operations_last_phase: int = self._operations_this_phase + self._operations_this_phase = 0 + + previous_phase: int = self._phase + self._phase = expected_phase + + # No-op if we did nothing for an entire phase + if operations_last_phase and self._phase > previous_phase: + self._increase_maximum_tokens() + + def _increase_maximum_tokens(self) -> NoReturn: + self._maximum_tokens = round(self._maximum_tokens * 1.5) + if self._global_max_tokens is not None: + self._maximum_tokens = min(self._maximum_tokens, self._global_max_tokens) + + def _refill(self) -> NoReturn: + """Replenishes any tokens that should have regenerated since the last + operation.""" + now: datetime.datetime = utcnow() + time_since_last_refill: datetime.timedelta = now - self._last_refill + + if time_since_last_refill: + self._last_refill = now + + # If we haven't done anything for 1s, then we know for certain we + # should reset to max capacity. + if time_since_last_refill.seconds >= 1: + self._available_tokens = self._maximum_tokens + + # If we have done something in the last 1s, then we know we should + # allocate proportional tokens. + else: + _percent_of_max: float = ( + time_since_last_refill.microseconds / microseconds_per_second + ) + new_tokens: int = round(_percent_of_max * self._maximum_tokens) + + # Add the number of provisioned tokens, capped at the maximum size. 
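+                # For example, 250,000 microseconds (0.25s) since the last
+                # refill against a 500-token maximum yields
+                # round(0.25 * 500) = 125 new tokens.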
+ self._available_tokens = min( + self._maximum_tokens, self._available_tokens + new_tokens, + ) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 6e72e65cf3..0975a73d09 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -1075,6 +1075,29 @@ def test_batch(client, cleanup): assert not document3.get().exists +def test_live_bulk_writer(client, cleanup): + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + db: Client = client + bw: BulkWriter = db.bulk_writer() + col = db.collection(f"bulkitems{UNIQUE_RESOURCE_ID}") + + for index in range(50): + doc_ref = col.document(f"id-{index}") + bw.create(doc_ref, {"index": index}) + cleanup(doc_ref.delete) + + bw.close() + assert bw._total_batches_sent >= 3 # retries could lead to more than 3 batches + assert bw._total_write_operations >= 50 # same retries rule applies again + assert bw._in_flight_documents == 0 + assert len(bw._operations) == 0 + + # And now assert that the documents were in fact written to the database + assert len(col.get()) == 50 + + def test_watch_document(client, cleanup): db = client collection_ref = db.collection("wd-users" + UNIQUE_RESOURCE_ID) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index ef8022f0e7..a4db4e75ff 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -1026,6 +1026,29 @@ async def test_get_all(client, cleanup): check_snapshot(snapshot3, document3, restricted3, write_result3) +async def test_live_bulk_writer(client, cleanup): + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + db: AsyncClient = client + bw: BulkWriter = db.bulk_writer() + col = db.collection(f"bulkitems-async{UNIQUE_RESOURCE_ID}") + + for index in range(50): + doc_ref = col.document(f"id-{index}") + bw.create(doc_ref, {"index": index}) + cleanup(doc_ref.delete) + + bw.close() + assert bw._total_batches_sent >= 3 # retries could lead to more than 3 batches + assert bw._total_write_operations >= 50 # same retries rule applies again + assert bw._in_flight_documents == 0 + assert len(bw._operations) == 0 + + # And now assert that the documents were in fact written to the database + assert len(await col.get()) == 50 + + async def test_batch(client, cleanup): collection_name = "batch" + UNIQUE_RESOURCE_ID diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py index 65aece0d4d..92d20b7ece 100644 --- a/tests/unit/v1/_test_helpers.py +++ b/tests/unit/v1/_test_helpers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import concurrent.futures import datetime import mock import typing @@ -82,3 +83,23 @@ def build_document_snapshot( create_time=create_time or build_timestamp(), update_time=update_time or build_timestamp(), ) + + +class FakeThreadPoolExecutor: + def __init__(self, *args, **kwargs): + self._shutdown = False + + def submit(self, callable) -> typing.NoReturn: + if self._shutdown: + raise RuntimeError( + "cannot schedule new futures after shutdown" + ) # pragma: NO COVER + future = concurrent.futures.Future() + future.set_result(callable()) + return future + + def shutdown(self): + self._shutdown = True + + def __repr__(self): + return f"FakeThreadPoolExecutor(shutdown={self._shutdown})" diff --git a/tests/unit/v1/test_async_batch.py b/tests/unit/v1/test_async_batch.py index dce1cefdf7..39f0d53914 100644 --- a/tests/unit/v1/test_async_batch.py +++ b/tests/unit/v1/test_async_batch.py @@ -20,6 +20,8 @@ class TestAsyncWriteBatch(aiounittest.AsyncTestCase): + """Tests the AsyncWriteBatch.commit method""" + @staticmethod def _get_target_class(): from google.cloud.firestore_v1.async_batch import AsyncWriteBatch diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index b766c22fcf..bb7a51dd83 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -373,6 +373,21 @@ async def test_get_all_unknown_result(self): metadata=client._rpc_metadata, ) + def test_bulk_writer(self): + """BulkWriter is opaquely async and thus does not have a dedicated + async variant.""" + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + client = self._make_default_one() + bulk_writer = client.bulk_writer() + self.assertIsInstance(bulk_writer, BulkWriter) + self.assertIs(bulk_writer._client, client._sync_copy) + + def test_sync_copy(self): + client = self._make_default_one() + # Multiple calls to this method should return the same cached instance. + self.assertIs(client._to_sync_copy(), client._to_sync_copy()) + def test_batch(self): from google.cloud.firestore_v1.async_batch import AsyncWriteBatch diff --git a/tests/unit/v1/test_base_batch.py b/tests/unit/v1/test_base_batch.py index affe0e1395..6bdb0da073 100644 --- a/tests/unit/v1/test_base_batch.py +++ b/tests/unit/v1/test_base_batch.py @@ -13,16 +13,26 @@ # limitations under the License. 
import unittest +from google.cloud.firestore_v1.base_batch import BaseWriteBatch import mock +class TestableBaseWriteBatch(BaseWriteBatch): + def __init__(self, client): + super().__init__(client=client) + + """Create a fake subclass of `BaseWriteBatch` for the purposes of + evaluating the shared methods.""" + + def commit(self): + pass # pragma: NO COVER + + class TestBaseWriteBatch(unittest.TestCase): @staticmethod def _get_target_class(): - from google.cloud.firestore_v1.base_batch import BaseWriteBatch - - return BaseWriteBatch + return TestableBaseWriteBatch def _make_one(self, *args, **kwargs): klass = self._get_target_class() diff --git a/tests/unit/v1/test_batch.py b/tests/unit/v1/test_batch.py index 119942fc34..3e3bef1ad8 100644 --- a/tests/unit/v1/test_batch.py +++ b/tests/unit/v1/test_batch.py @@ -18,6 +18,8 @@ class TestWriteBatch(unittest.TestCase): + """Tests the WriteBatch.commit method""" + @staticmethod def _get_target_class(): from google.cloud.firestore_v1.batch import WriteBatch @@ -61,6 +63,7 @@ def _commit_helper(self, retry=None, timeout=None): batch.create(document1, {"ten": 10, "buck": "ets"}) document2 = client.document("c", "d", "e", "f") batch.delete(document2) + self.assertEqual(len(batch), 2) write_pbs = batch._write_pbs[::] write_results = batch.commit(**kwargs) diff --git a/tests/unit/v1/test_bulk_batch.py b/tests/unit/v1/test_bulk_batch.py new file mode 100644 index 0000000000..20d43b9ccc --- /dev/null +++ b/tests/unit/v1/test_bulk_batch.py @@ -0,0 +1,105 @@ +# Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import mock + + +class TestBulkWriteBatch(unittest.TestCase): + """Tests the BulkWriteBatch.commit method""" + + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + + return BulkWriteBatch + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + batch = self._make_one(mock.sentinel.client) + self.assertIs(batch._client, mock.sentinel.client) + self.assertEqual(batch._write_pbs, []) + self.assertIsNone(batch.write_results) + + def _write_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.Mock(spec=["batch_write"]) + write_response = firestore.BatchWriteResponse( + write_results=[write.WriteResult(), write.WriteResult()], + ) + firestore_api.batch_write.return_value = write_response + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Attach the fake GAPIC to a real client. + client = _make_client("grand") + client._firestore_api_internal = firestore_api + + # Actually make a batch with some mutations and call commit(). 
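+        # The membership checks below also exercise `BaseBatch.__contains__`,
+        # which tracks every document reference added to the batch.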
+ batch = self._make_one(client) + document1 = client.document("a", "b") + self.assertFalse(document1 in batch) + batch.create(document1, {"ten": 10, "buck": "ets"}) + self.assertTrue(document1 in batch) + document2 = client.document("c", "d", "e", "f") + batch.delete(document2) + write_pbs = batch._write_pbs[::] + + resp = batch.commit(**kwargs) + self.assertEqual(resp.write_results, list(write_response.write_results)) + self.assertEqual(batch.write_results, resp.write_results) + # Make sure batch has no more "changes". + self.assertEqual(batch._write_pbs, []) + + # Verify the mocks. + firestore_api.batch_write.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "labels": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + def test_write(self): + self._write_helper() + + def test_write_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + + self._write_helper(retry=retry, timeout=timeout) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="seventy-nine"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) diff --git a/tests/unit/v1/test_bulk_writer.py b/tests/unit/v1/test_bulk_writer.py new file mode 100644 index 0000000000..685d48a525 --- /dev/null +++ b/tests/unit/v1/test_bulk_writer.py @@ -0,0 +1,600 @@ +# # Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
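The tests that follow exercise the public `BulkWriter` surface through a `NoSendBulkWriter` fake. For orientation, here is a hedged sketch of the call pattern those tests simulate against a real project; the collection name and document data are illustrative only, and the client assumes default credentials:

```py
from google.cloud import firestore
from google.cloud.firestore_v1.bulk_writer import BulkRetry, BulkWriterOptions

client = firestore.Client()  # assumes default credentials and project
bw = client.bulk_writer(options=BulkWriterOptions(retry=BulkRetry.immediate))

# Callbacks receive the same arguments the tests below assert on.
bw.on_batch_result(lambda batch, response, writer: print("batch sent"))
bw.on_write_result(lambda ref, result, writer: print("wrote", ref.path))
bw.on_write_error(lambda error, writer: error.attempts < 3)  # True -> retry

for index in range(101):
    bw.create(client.collection("bulk-demo").document(), {"index": index})

bw.flush()  # drains queued operations; 101 docs land in 6 batches by default
bw.close()  # after close, no further operations are accepted
```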
+ +import datetime +import unittest +from typing import List, NoReturn, Optional, Tuple, Type + +from google.rpc import status_pb2 +import aiounittest # type: ignore + +from google.cloud.firestore_v1._helpers import build_timestamp, ExistsOption +from google.cloud.firestore_v1.async_client import AsyncClient +from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.client import Client +from google.cloud.firestore_v1.base_client import BaseClient +from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch +from google.cloud.firestore_v1.bulk_writer import ( + BulkRetry, + BulkWriter, + BulkWriteFailure, + BulkWriterCreateOperation, + BulkWriterOptions, + BulkWriterOperation, + OperationRetry, + SendMode, +) +from google.cloud.firestore_v1.types.firestore import BatchWriteResponse +from google.cloud.firestore_v1.types.write import WriteResult +from tests.unit.v1._test_helpers import FakeThreadPoolExecutor + + +class NoSendBulkWriter(BulkWriter): + """Test-friendly BulkWriter subclass whose `_send` method returns faked + BatchWriteResponse instances and whose _process_response` method stores + those faked instances for later evaluation.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._responses: List[ + Tuple[BulkWriteBatch, BatchWriteResponse, BulkWriterOperation] + ] = [] + self._fail_indices: List[int] = [] + + def _send(self, batch: BulkWriteBatch) -> BatchWriteResponse: + """Generate a fake `BatchWriteResponse` for the supplied batch instead + of actually submitting it to the server. + """ + return BatchWriteResponse( + write_results=[ + WriteResult(update_time=build_timestamp()) + if index not in self._fail_indices + else WriteResult() + for index, el in enumerate(batch._document_references.values()) + ], + status=[ + status_pb2.Status(code=0 if index not in self._fail_indices else 1) + for index, el in enumerate(batch._document_references.values()) + ], + ) + + def _process_response( + self, + batch: BulkWriteBatch, + response: BatchWriteResponse, + operations: List[BulkWriterOperation], + ) -> NoReturn: + super()._process_response(batch, response, operations) + self._responses.append((batch, response, operations)) + + def _instantiate_executor(self): + return FakeThreadPoolExecutor() + + +class _SyncClientMixin: + """Mixin which helps a `_BaseBulkWriterTests` subclass simulate usage of + synchronous Clients, Collections, DocumentReferences, etc.""" + + def _get_client_class(self) -> Type: + return Client + + +class _AsyncClientMixin: + """Mixin which helps a `_BaseBulkWriterTests` subclass simulate usage of + AsyncClients, AsyncCollections, AsyncDocumentReferences, etc.""" + + def _get_client_class(self) -> Type: + return AsyncClient + + +class _BaseBulkWriterTests: + def setUp(self): + self.client: BaseClient = self._get_client_class()() + + def _get_document_reference( + self, collection_name: Optional[str] = "col", id: Optional[str] = None, + ) -> Type: + return self.client.collection(collection_name).document(id) + + def _doc_iter(self, num: int, ids: Optional[List[str]] = None): + for _ in range(num): + id: Optional[str] = ids[_] if ids else None + yield self._get_document_reference(id=id), {"id": _} + + def _verify_bw_activity(self, bw: BulkWriter, counts: List[Tuple[int, int]]): + """ + Args: + bw: (BulkWriter) + The BulkWriter instance to inspect. 
+            counts: (tuple) A sequence of integer pairs, with 0-index integers
+                representing the size of sent batches, and 1-index integers
+                representing the number of times batches of that size should
+                have been sent.
+        """
+        total_batches = sum([el[1] for el in counts])
+        batches_word = "batches" if total_batches != 1 else "batch"
+        self.assertEqual(
+            len(bw._responses),
+            total_batches,
+            f"Expected to have sent {total_batches} {batches_word}, but only sent {len(bw._responses)}",
+        )
+        docs_count = {}
+        resp: BatchWriteResponse
+        for _, resp, ops in bw._responses:
+            docs_count.setdefault(len(resp.write_results), 0)
+            docs_count[len(resp.write_results)] += 1
+
+        self.assertEqual(len(docs_count), len(counts))
+        for size, num_sent in counts:
+            self.assertEqual(docs_count[size], num_sent)
+
+        # Assert flush leaves no operation behind
+        self.assertEqual(len(bw._operations), 0)
+
+    def test_create_calls_send_correctly(self):
+        bw = NoSendBulkWriter(self.client)
+        for ref, data in self._doc_iter(101):
+            bw.create(ref, data)
+        bw.flush()
+        # Full batches with 20 items should have been sent 5 times, and a 1-item
+        # batch should have been sent once.
+        self._verify_bw_activity(bw, [(20, 5,), (1, 1,)])
+
+    def test_delete_calls_send_correctly(self):
+        bw = NoSendBulkWriter(self.client)
+        for ref, _ in self._doc_iter(101):
+            bw.delete(ref)
+        bw.flush()
+        # Full batches with 20 items should have been sent 5 times, and a 1-item
+        # batch should have been sent once.
+        self._verify_bw_activity(bw, [(20, 5,), (1, 1,)])
+
+    def test_delete_separates_batch(self):
+        bw = NoSendBulkWriter(self.client)
+        ref = self._get_document_reference(id="asdf")
+        bw.create(ref, {})
+        bw.delete(ref)
+        bw.flush()
+        # Consecutive batches each with 1 operation should have been sent
+        self._verify_bw_activity(bw, [(1, 2,)])
+
+    def test_set_calls_send_correctly(self):
+        bw = NoSendBulkWriter(self.client)
+        for ref, data in self._doc_iter(101):
+            bw.set(ref, data)
+        bw.flush()
+        # Full batches with 20 items should have been sent 5 times, and a 1-item
+        # batch should have been sent once.
+        self._verify_bw_activity(bw, [(20, 5,), (1, 1,)])
+
+    def test_update_calls_send_correctly(self):
+        bw = NoSendBulkWriter(self.client)
+        for ref, data in self._doc_iter(101):
+            bw.update(ref, data)
+        bw.flush()
+        # Full batches with 20 items should have been sent 5 times, and a 1-item
+        # batch should have been sent once.
+        self._verify_bw_activity(bw, [(20, 5,), (1, 1,)])
+
+    def test_update_separates_batch(self):
+        bw = NoSendBulkWriter(self.client)
+        ref = self._get_document_reference(id="asdf")
+        bw.create(ref, {})
+        bw.update(ref, {"field": "value"})
+        bw.flush()
+        # Consecutive batches each with 1 operation should have been sent,
+        # since the create and the update target the same document.
+ self._verify_bw_activity(bw, [(1, 2,)]) + + def test_invokes_success_callbacks_successfully(self): + bw = NoSendBulkWriter(self.client) + bw._fail_indices = [] + bw._sent_batches = 0 + bw._sent_documents = 0 + + def _on_batch(batch, response, bulk_writer): + assert isinstance(batch, BulkWriteBatch) + assert isinstance(response, BatchWriteResponse) + assert isinstance(bulk_writer, BulkWriter) + bulk_writer._sent_batches += 1 + + def _on_write(ref, result, bulk_writer): + assert isinstance(ref, BaseDocumentReference) + assert isinstance(result, WriteResult) + assert isinstance(bulk_writer, BulkWriter) + bulk_writer._sent_documents += 1 + + bw.on_write_result(_on_write) + bw.on_batch_result(_on_batch) + + for ref, data in self._doc_iter(101): + bw.create(ref, data) + bw.flush() + + self.assertEqual(bw._sent_batches, 6) + self.assertEqual(bw._sent_documents, 101) + self.assertEqual(len(bw._operations), 0) + + def test_invokes_error_callbacks_successfully(self): + bw = NoSendBulkWriter(self.client) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._sent_batches = 0 + bw._sent_documents = 0 + bw._total_retries = 0 + + times_to_retry = 1 + + def _on_batch(batch, response, bulk_writer): + bulk_writer._sent_batches += 1 + + def _on_write(ref, result, bulk_writer): + bulk_writer._sent_documents += 1 # pragma: NO COVER + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_batch_result(_on_batch) + bw.on_write_result(_on_write) + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(1): + bw.create(ref, data) + bw.flush() + + self.assertEqual(bw._sent_documents, 0) + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(bw._sent_batches, 2) + self.assertEqual(len(bw._operations), 0) + + def test_invokes_error_callbacks_successfully_multiple_retries(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._sent_batches = 0 + bw._sent_documents = 0 + bw._total_retries = 0 + + times_to_retry = 10 + + def _on_batch(batch, response, bulk_writer): + bulk_writer._sent_batches += 1 + + def _on_write(ref, result, bulk_writer): + bulk_writer._sent_documents += 1 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_batch_result(_on_batch) + bw.on_write_result(_on_write) + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(2): + bw.create(ref, data) + bw.flush() + + self.assertEqual(bw._sent_documents, 1) + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(bw._sent_batches, times_to_retry + 1) + self.assertEqual(len(bw._operations), 0) + + def test_default_error_handler(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + bw._attempts = 0 + + def _on_error(error, bw): + bw._attempts = error.attempts + return bw._default_on_error(error, bw) + + bw.on_write_error(_on_error) + + # First document in each batch will "fail" + bw._fail_indices = [0] + for ref, data in self._doc_iter(1): + bw.create(ref, data) + bw.flush() + self.assertEqual(bw._attempts, 15) + + def test_handles_errors_and_successes_correctly(self): + bw = NoSendBulkWriter( + self.client, 
options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._sent_batches = 0 + bw._sent_documents = 0 + bw._total_retries = 0 + + times_to_retry = 1 + + def _on_batch(batch, response, bulk_writer): + bulk_writer._sent_batches += 1 + + def _on_write(ref, result, bulk_writer): + bulk_writer._sent_documents += 1 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_batch_result(_on_batch) + bw.on_write_result(_on_write) + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(40): + bw.create(ref, data) + bw.flush() + + # 19 successful writes per batch + self.assertEqual(bw._sent_documents, 38) + self.assertEqual(bw._total_retries, times_to_retry * 2) + self.assertEqual(bw._sent_batches, 4) + self.assertEqual(len(bw._operations), 0) + + def test_create_retriable(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._total_retries = 0 + times_to_retry = 6 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(1): + bw.create(ref, data) + bw.flush() + + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(len(bw._operations), 0) + + def test_delete_retriable(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._total_retries = 0 + times_to_retry = 6 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_write_error(_on_error) + + for ref, _ in self._doc_iter(1): + bw.delete(ref) + bw.flush() + + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(len(bw._operations), 0) + + def test_set_retriable(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._total_retries = 0 + times_to_retry = 6 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(1): + bw.set(ref, data) + bw.flush() + + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(len(bw._operations), 0) + + def test_update_retriable(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._total_retries = 0 + times_to_retry = 6 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(1): + bw.update(ref, data) + bw.flush() + + self.assertEqual(bw._total_retries, times_to_retry) + 
self.assertEqual(len(bw._operations), 0) + + def test_serial_calls_send_correctly(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(mode=SendMode.serial) + ) + for ref, data in self._doc_iter(101): + bw.create(ref, data) + bw.flush() + # Full batches with 20 items should have been sent 5 times, and a 1-item + # batch should have been sent once. + self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) + + def test_separates_same_document(self): + bw = NoSendBulkWriter(self.client) + for ref, data in self._doc_iter(2, ["same-id", "same-id"]): + bw.create(ref, data) + bw.flush() + # Seeing the same document twice should lead to separate batches + # Expect to have sent 1-item batches twice. + self._verify_bw_activity(bw, [(1, 2,)]) + + def test_separates_same_document_different_operation(self): + bw = NoSendBulkWriter(self.client) + for ref, data in self._doc_iter(1, ["same-id"]): + bw.create(ref, data) + bw.set(ref, data) + bw.flush() + # Seeing the same document twice should lead to separate batches. + # Expect to have sent 1-item batches twice. + self._verify_bw_activity(bw, [(1, 2,)]) + + def test_ensure_sending_repeatedly_callable(self): + bw = NoSendBulkWriter(self.client) + bw._is_sending = True + bw._ensure_sending() + + def test_flush_close_repeatedly_callable(self): + bw = NoSendBulkWriter(self.client) + bw.flush() + bw.flush() + bw.close() + + def test_flush_sends_in_progress(self): + bw = NoSendBulkWriter(self.client) + bw.create(self._get_document_reference(), {"whatever": "you want"}) + bw.flush() + self._verify_bw_activity(bw, [(1, 1,)]) + + def test_flush_sends_all_queued_batches(self): + bw = NoSendBulkWriter(self.client) + for _ in range(2): + bw.create(self._get_document_reference(), {"whatever": "you want"}) + bw._queued_batches.append(bw._operations) + bw._reset_operations() + bw.flush() + self._verify_bw_activity(bw, [(1, 2,)]) + + def test_cannot_add_after_close(self): + bw = NoSendBulkWriter(self.client) + bw.close() + self.assertRaises(Exception, bw._verify_not_closed) + + def test_multiple_flushes(self): + bw = NoSendBulkWriter(self.client) + bw.flush() + bw.flush() + + def test_update_raises_with_bad_option(self): + bw = NoSendBulkWriter(self.client) + self.assertRaises( + ValueError, + bw.update, + self._get_document_reference("id"), + {}, + option=ExistsOption(exists=True), + ) + + +class TestSyncBulkWriter(_SyncClientMixin, _BaseBulkWriterTests, unittest.TestCase): + """All BulkWriters are opaquely async, but this one simulates a BulkWriter + dealing with synchronous DocumentReferences.""" + + +class TestAsyncBulkWriter( + _AsyncClientMixin, _BaseBulkWriterTests, aiounittest.AsyncTestCase +): + """All BulkWriters are opaquely async, but this one simulates a BulkWriter + dealing with AsyncDocumentReferences.""" + + +class TestScheduling(unittest.TestCase): + def test_max_in_flight_honored(self): + bw = NoSendBulkWriter(Client()) + # Calling this method sets up all the internal timekeeping machinery + bw._rate_limiter.take_tokens(20) + + # Now we pretend that all tokens have been consumed. This will force us + # to wait actual, real world milliseconds before being cleared to send more + bw._rate_limiter._available_tokens = 0 + + st = datetime.datetime.now() + + # Make a real request, subject to the actual real world clock. 
+ # As this request is 1/10th the per second limit, we should wait ~100ms + bw._request_send(50) + + self.assertGreater( + datetime.datetime.now() - st, datetime.timedelta(milliseconds=90), + ) + + def test_operation_retry_scheduling(self): + now = datetime.datetime.now() + one_second_from_now = now + datetime.timedelta(seconds=1) + + db = Client() + operation = BulkWriterCreateOperation( + reference=db.collection("asdf").document("asdf"), + document_data={"does.not": "matter"}, + ) + operation2 = BulkWriterCreateOperation( + reference=db.collection("different").document("document"), + document_data={"different": "values"}, + ) + + op1 = OperationRetry(operation=operation, run_at=now) + op2 = OperationRetry(operation=operation2, run_at=now) + op3 = OperationRetry(operation=operation, run_at=one_second_from_now) + + self.assertLess(op1, op3) + self.assertLess(op1, op3.run_at) + self.assertLess(op2, op3) + self.assertLess(op2, op3.run_at) + + # Because these have the same values for `run_at`, neither should conclude + # they are less than the other. It is okay that if we checked them with + # greater-than evaluation, they would return True (because + # @functools.total_ordering flips the result from __lt__). In practice, + # this only arises for actual ties, and we don't care how actual ties are + # ordered as we maintain the sorted list of scheduled retries. + self.assertFalse(op1 < op2) + self.assertFalse(op2 < op1) diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 0055dab2ca..a46839ac59 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -369,6 +369,14 @@ def test_batch(self): self.assertIs(batch._client, client) self.assertEqual(batch._write_pbs, []) + def test_bulk_writer(self): + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + client = self._make_default_one() + bulk_writer = client.bulk_writer() + self.assertIsInstance(bulk_writer, BulkWriter) + self.assertIs(bulk_writer._client, client) + def test_transaction(self): from google.cloud.firestore_v1.transaction import Transaction diff --git a/tests/unit/v1/test_rate_limiter.py b/tests/unit/v1/test_rate_limiter.py new file mode 100644 index 0000000000..ea41905e49 --- /dev/null +++ b/tests/unit/v1/test_rate_limiter.py @@ -0,0 +1,200 @@ +# Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import unittest +from typing import Optional + +import mock +import google +from google.cloud.firestore_v1 import rate_limiter + + +# Pick a point in time as the center of our universe for this test run. +# It is okay for this to update every time the tests are run. 
+fake_now = datetime.datetime.utcnow() + + +def now_plus_n( + seconds: Optional[int] = 0, microseconds: Optional[int] = 0, +) -> datetime.timedelta: + return fake_now + datetime.timedelta(seconds=seconds, microseconds=microseconds,) + + +class TestRateLimiter(unittest.TestCase): + @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") + def test_rate_limiter_basic(self, mocked_now): + """Verifies that if the clock does not advance, the RateLimiter allows 500 + writes before crashing out. + """ + mocked_now.return_value = fake_now + # This RateLimiter will never advance. Poor fella. + ramp = rate_limiter.RateLimiter() + for _ in range(rate_limiter.default_initial_tokens): + self.assertEqual(ramp.take_tokens(), 1) + self.assertEqual(ramp.take_tokens(), 0) + + @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") + def test_rate_limiter_with_refill(self, mocked_now): + """Verifies that if the clock advances, the RateLimiter allows appropriate + additional writes. + """ + mocked_now.return_value = fake_now + ramp = rate_limiter.RateLimiter() + ramp._available_tokens = 0 + self.assertEqual(ramp.take_tokens(), 0) + # Advance the clock 0.1 seconds + mocked_now.return_value = now_plus_n(microseconds=100000) + for _ in range(round(rate_limiter.default_initial_tokens / 10)): + self.assertEqual(ramp.take_tokens(), 1) + self.assertEqual(ramp.take_tokens(), 0) + + @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") + def test_rate_limiter_phase_length(self, mocked_now): + """Verifies that if the clock advances, the RateLimiter allows appropriate + additional writes. + """ + mocked_now.return_value = fake_now + ramp = rate_limiter.RateLimiter() + self.assertEqual(ramp.take_tokens(), 1) + ramp._available_tokens = 0 + self.assertEqual(ramp.take_tokens(), 0) + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + for _ in range(round(rate_limiter.default_initial_tokens * 3 / 2)): + self.assertTrue( + ramp.take_tokens(), msg=f"token {_} should have been allowed" + ) + self.assertEqual(ramp.take_tokens(), 0) + + @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") + def test_rate_limiter_idle_phase_length(self, mocked_now): + """Verifies that if the clock advances but nothing happens, the RateLimiter + doesn't ramp up. + """ + mocked_now.return_value = fake_now + ramp = rate_limiter.RateLimiter() + ramp._available_tokens = 0 + self.assertEqual(ramp.take_tokens(), 0) + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + for _ in range(round(rate_limiter.default_initial_tokens)): + self.assertEqual( + ramp.take_tokens(), 1, msg=f"token {_} should have been allowed" + ) + self.assertEqual(ramp._maximum_tokens, 500) + self.assertEqual(ramp.take_tokens(), 0) + + @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") + def test_take_batch_size(self, mocked_now): + """Verifies that if the clock advances but nothing happens, the RateLimiter + doesn't ramp up. 
+ """ + page_size: int = 20 + mocked_now.return_value = fake_now + ramp = rate_limiter.RateLimiter() + ramp._available_tokens = 15 + self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 15) + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + ramp._check_phase() + self.assertEqual(ramp._maximum_tokens, 750) + + for _ in range(740 // page_size): + self.assertEqual( + ramp.take_tokens(page_size), + page_size, + msg=f"page {_} should have been allowed", + ) + self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 10) + self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 0) + + @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") + def test_phase_progress(self, mocked_now): + mocked_now.return_value = fake_now + + ramp = rate_limiter.RateLimiter() + self.assertEqual(ramp._phase, 0) + self.assertEqual(ramp._maximum_tokens, 500) + ramp.take_tokens() + + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + ramp.take_tokens() + self.assertEqual(ramp._phase, 1) + self.assertEqual(ramp._maximum_tokens, 750) + + # Advance the clock another phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length * 2, microseconds=1, + ) + ramp.take_tokens() + self.assertEqual(ramp._phase, 2) + self.assertEqual(ramp._maximum_tokens, 1125) + + # Advance the clock another ms and the phase should not advance + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length * 2, microseconds=2, + ) + ramp.take_tokens() + self.assertEqual(ramp._phase, 2) + self.assertEqual(ramp._maximum_tokens, 1125) + + @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") + def test_global_max_tokens(self, mocked_now): + mocked_now.return_value = fake_now + + ramp = rate_limiter.RateLimiter(global_max_tokens=499,) + self.assertEqual(ramp._phase, 0) + self.assertEqual(ramp._maximum_tokens, 499) + ramp.take_tokens() + + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + ramp.take_tokens() + self.assertEqual(ramp._phase, 1) + self.assertEqual(ramp._maximum_tokens, 499) + + # Advance the clock another phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length * 2, microseconds=1, + ) + ramp.take_tokens() + self.assertEqual(ramp._phase, 2) + self.assertEqual(ramp._maximum_tokens, 499) + + # Advance the clock another ms and the phase should not advance + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length * 2, microseconds=2, + ) + ramp.take_tokens() + self.assertEqual(ramp._phase, 2) + self.assertEqual(ramp._maximum_tokens, 499) + + def test_utcnow(self): + self.assertTrue( + isinstance( + google.cloud.firestore_v1.rate_limiter.utcnow(), datetime.datetime, + ) + ) From 639d55277c9c35b99be36921330b7b7db2fd6e4c Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Wed, 11 Aug 2021 14:23:21 -0400 Subject: [PATCH 11/19] chore: avoid `.nox` directories when building docs (#419) Source-Link: https://github.com/googleapis/synthtool/commit/7e1f6da50524b5d98eb67adbf6dd0805df54233d Post-Processor: gcr.io/repo-automation-bots/owlbot-python:latest@sha256:a1a891041baa4ffbe1a809ac1b8b9b4a71887293c9101c88e8e255943c5aec2d Co-authored-by: Owl Bot --- .github/.OwlBot.lock.yaml | 2 +- 
docs/conf.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml index 9ee60f7e48..b771c37cae 100644 --- a/.github/.OwlBot.lock.yaml +++ b/.github/.OwlBot.lock.yaml @@ -1,3 +1,3 @@ docker: image: gcr.io/repo-automation-bots/owlbot-python:latest - digest: sha256:aea14a583128771ae8aefa364e1652f3c56070168ef31beb203534222d842b8b + digest: sha256:a1a891041baa4ffbe1a809ac1b8b9b4a71887293c9101c88e8e255943c5aec2d diff --git a/docs/conf.py b/docs/conf.py index a7bb6eb61e..df14f3d5b1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -110,6 +110,7 @@ # directories to ignore when looking for source files. exclude_patterns = [ "_build", + "**/.nox/**/*", "samples/AUTHORING_GUIDE.md", "samples/CONTRIBUTING.md", "samples/snippets/README.rst", From a1e9a162d5625eed8e79f31598541bbe40ca853f Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 11 Aug 2021 22:16:28 -0400 Subject: [PATCH 12/19] tests: allow prerelease deps on Python 3.9 (#415) Closes #414. --- testing/constraints-3.9.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index e69de29bb2..6d34489a53 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -0,0 +1,2 @@ +# Allow prerelease requirements +--pre From 539c1d719191eb0ae3a49290c26b628de7c27cd5 Mon Sep 17 00:00:00 2001 From: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com> Date: Thu, 12 Aug 2021 15:24:19 -0600 Subject: [PATCH 13/19] fix: remove unused requirement pytz (#422) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: remove unused requirement pytz * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md Co-authored-by: Owl Bot --- setup.py | 1 - tests/unit/v1/test_watch.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 5e913edcf2..50de98e267 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,6 @@ # https://github.com/googleapis/google-cloud-python/issues/10566 "google-cloud-core >= 1.4.1, <3.0.0dev", "packaging >= 14.3", - "pytz", "proto-plus >= 1.10.0", ] extras = {} diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index 759549b72a..c5b758459f 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -566,9 +566,9 @@ def test_on_snapshot_unknown_listen_type(self): ) def test_push_callback_called_no_changes(self): - import pytz - - dummy_time = (datetime.datetime.fromtimestamp(1534858278, pytz.utc),) + dummy_time = ( + datetime.datetime.fromtimestamp(1534858278, datetime.timezone.utc), + ) inst = self._makeOne() inst.push(dummy_time, "token") From 0923c955090191a7ac4dce032562403396fd39ad Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Fri, 13 Aug 2021 11:35:26 -0400 Subject: [PATCH 14/19] chore: drop mention of Python 2.7 from templates (#423) Source-Link: https://github.com/googleapis/synthtool/commit/facee4cc1ea096cd8bcc008bb85929daa7c414c0 Post-Processor: gcr.io/repo-automation-bots/owlbot-python:latest@sha256:9743664022bd63a8084be67f144898314c7ca12f0a03e422ac17c733c129d803 Co-authored-by: Owl Bot --- .github/.OwlBot.lock.yaml | 2 +- noxfile.py | 11 ++++++++--- scripts/readme-gen/templates/install_deps.tmpl.rst | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml index b771c37cae..a9fcd07cc4 100644 
--- a/.github/.OwlBot.lock.yaml +++ b/.github/.OwlBot.lock.yaml @@ -1,3 +1,3 @@ docker: image: gcr.io/repo-automation-bots/owlbot-python:latest - digest: sha256:a1a891041baa4ffbe1a809ac1b8b9b4a71887293c9101c88e8e255943c5aec2d + digest: sha256:9743664022bd63a8084be67f144898314c7ca12f0a03e422ac17c733c129d803 diff --git a/noxfile.py b/noxfile.py index ff4bb10c4c..0e6354ceea 100644 --- a/noxfile.py +++ b/noxfile.py @@ -93,11 +93,16 @@ def default(session): constraints_path = str( CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" ) - session.install("asyncmock", "pytest-asyncio", "-c", constraints_path) - session.install( - "mock", "pytest", "pytest-cov", "aiounittest", "-c", constraints_path + "mock", + "asyncmock", + "pytest", + "pytest-cov", + "pytest-asyncio", + "-c", + constraints_path, ) + session.install("aiounittest", "-c", constraints_path) session.install("-e", ".", "-c", constraints_path) diff --git a/scripts/readme-gen/templates/install_deps.tmpl.rst b/scripts/readme-gen/templates/install_deps.tmpl.rst index a0406dba8c..275d649890 100644 --- a/scripts/readme-gen/templates/install_deps.tmpl.rst +++ b/scripts/readme-gen/templates/install_deps.tmpl.rst @@ -12,7 +12,7 @@ Install Dependencies .. _Python Development Environment Setup Guide: https://cloud.google.com/python/setup -#. Create a virtualenv. Samples are compatible with Python 2.7 and 3.4+. +#. Create a virtualenv. Samples are compatible with Python 3.6+. .. code-block:: bash From 813a57b1070a1f6ac41d02897fab33f8039b83f9 Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Mon, 16 Aug 2021 15:11:56 -0700 Subject: [PATCH 15/19] feat: add recursive delete (#420) * feat: add recursive delete * made chunkify private Co-authored-by: Christopher Wilcox --- google/cloud/firestore_v1/async_client.py | 84 ++++- google/cloud/firestore_v1/async_collection.py | 4 + google/cloud/firestore_v1/async_query.py | 44 ++- google/cloud/firestore_v1/base_client.py | 13 +- google/cloud/firestore_v1/base_document.py | 4 +- google/cloud/firestore_v1/base_query.py | 6 + google/cloud/firestore_v1/client.py | 88 +++++- google/cloud/firestore_v1/collection.py | 3 + google/cloud/firestore_v1/query.py | 44 ++- tests/system/test_system.py | 274 +++++++++++++---- tests/system/test_system_async.py | 287 +++++++++++++----- tests/unit/v1/test_async_client.py | 106 +++++++ tests/unit/v1/test_async_collection.py | 38 +++ tests/unit/v1/test_async_query.py | 24 ++ tests/unit/v1/test_client.py | 100 ++++++ tests/unit/v1/test_collection.py | 37 +++ tests/unit/v1/test_query.py | 36 +++ 17 files changed, 1046 insertions(+), 146 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 68cb676f2a..a4be110020 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -43,13 +43,17 @@ DocumentSnapshot, ) from google.cloud.firestore_v1.async_transaction import AsyncTransaction +from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.services.firestore import ( async_client as firestore_client, ) from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) -from typing import Any, AsyncGenerator, Iterable, List +from typing import Any, AsyncGenerator, Iterable, List, Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER class AsyncClient(BaseClient): @@ -300,6 +304,84 @@ async def collections( 
async for collection_id in iterator: yield self.collection(collection_id) + async def recursive_delete( + self, + reference: Union[AsyncCollectionReference, AsyncDocumentReference], + *, + bulk_writer: Optional["BulkWriter"] = None, + chunk_size: Optional[int] = 5000, + ): + """Deletes documents and their subcollections, regardless of collection + name. + + Passing an AsyncCollectionReference leads to each document in the + collection getting deleted, as well as all of their descendents. + + Passing an AsyncDocumentReference deletes that one document and all of + its descendents. + + Args: + reference (Union[ + :class:`@google.cloud.firestore_v1.async_collection.CollectionReference`, + :class:`@google.cloud.firestore_v1.async_document.DocumentReference`, + ]) + The reference to be deleted. + + bulk_writer (Optional[:class:`@google.cloud.firestore_v1.bulk_writer.BulkWriter`]) + The BulkWriter used to delete all matching documents. Supply this + if you want to override the default throttling behavior. + """ + return await self._recursive_delete( + reference, bulk_writer=bulk_writer, chunk_size=chunk_size, + ) + + async def _recursive_delete( + self, + reference: Union[AsyncCollectionReference, AsyncDocumentReference], + *, + bulk_writer: Optional["BulkWriter"] = None, # type: ignore + chunk_size: Optional[int] = 5000, + depth: Optional[int] = 0, + ) -> int: + """Recursion helper for `recursive_delete.""" + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + bulk_writer = bulk_writer or BulkWriter() + + num_deleted: int = 0 + + if isinstance(reference, AsyncCollectionReference): + chunk: List[DocumentSnapshot] + async for chunk in reference.recursive().select( + [FieldPath.document_id()] + )._chunkify(chunk_size): + doc_snap: DocumentSnapshot + for doc_snap in chunk: + num_deleted += 1 + bulk_writer.delete(doc_snap.reference) + + elif isinstance(reference, AsyncDocumentReference): + col_ref: AsyncCollectionReference + async for col_ref in reference.collections(): + num_deleted += await self._recursive_delete( + col_ref, + bulk_writer=bulk_writer, + depth=depth + 1, + chunk_size=chunk_size, + ) + num_deleted += 1 + bulk_writer.delete(reference) + + else: + raise TypeError( + f"Unexpected type for reference: {reference.__class__.__name__}" + ) + + if depth == 0: + bulk_writer.close() + + return num_deleted + def batch(self) -> AsyncWriteBatch: """Get a batch instance from this client. 
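The new `AsyncClient.recursive_delete` above walks a collection or document reference and feeds every discovered descendant to a `BulkWriter`, returning the number of documents deleted. A hedged usage sketch; the collection name is illustrative and the client assumes default credentials:

```py
import asyncio

from google.cloud import firestore


async def purge() -> None:
    client = firestore.AsyncClient()  # assumes default credentials/project
    # Deletes every document in the collection plus all of their
    # subcollections, paging the underlying query 1000 documents at a time.
    num_deleted = await client.recursive_delete(
        client.collection("philosophers"), chunk_size=1000
    )
    print(f"recursively deleted {num_deleted} documents")


asyncio.run(purge())
```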
diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index ca4ec8b0ff..d064051271 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -72,6 +72,10 @@ def _query(self) -> async_query.AsyncQuery: """ return async_query.AsyncQuery(self) + async def _chunkify(self, chunk_size: int): + async for page in self._query()._chunkify(chunk_size): + yield page + async def add( self, document_data: dict, diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 2f94b5f7c9..0444b92bc7 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -33,7 +33,8 @@ ) from google.cloud.firestore_v1 import async_document -from typing import AsyncGenerator, Type +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from typing import AsyncGenerator, List, Optional, Type # Types needed only for Type Hints from google.cloud.firestore_v1.transaction import Transaction @@ -126,6 +127,47 @@ def __init__( recursive=recursive, ) + async def _chunkify( + self, chunk_size: int + ) -> AsyncGenerator[List[DocumentSnapshot], None]: + # Catch the edge case where a developer writes the following: + # `my_query.limit(500)._chunkify(1000)`, which ultimately nullifies any + # need to yield chunks. + if self._limit and chunk_size > self._limit: + yield await self.get() + return + + max_to_return: Optional[int] = self._limit + num_returned: int = 0 + original: AsyncQuery = self._copy() + last_document: Optional[DocumentSnapshot] = None + + while True: + # Optionally trim the `chunk_size` down to honor a previously + # applied limit as set by `self.limit()` + _chunk_size: int = original._resolve_chunk_size(num_returned, chunk_size) + + # Apply the optionally pruned limit and the cursor, if we are past + # the first page. + _q = original.limit(_chunk_size) + if last_document: + _q = _q.start_after(last_document) + + snapshots = await _q.get() + last_document = snapshots[-1] + num_returned += len(snapshots) + + yield snapshots + + # Terminate the iterator if we have reached either of two end + # conditions: + # 1. There are no more documents, or + # 2. 
We have reached the desired overall limit + if len(snapshots) < _chunk_size or ( + max_to_return and num_returned >= max_to_return + ): + return + async def get( self, transaction: Transaction = None, diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index e68031ed4d..17068a9740 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -37,11 +37,9 @@ from google.cloud.firestore_v1 import __version__ from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.base_document import DocumentSnapshot -from google.cloud.firestore_v1.bulk_writer import ( - BulkWriter, - BulkWriterOptions, -) + from google.cloud.firestore_v1.field_path import render_field_path +from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from typing import ( Any, AsyncGenerator, @@ -312,6 +310,13 @@ def _document_path_helper(self, *document_path) -> List[str]: joined_path = joined_path[len(base_path) :] return joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) + def recursive_delete( + self, + reference: Union[BaseCollectionReference, BaseDocumentReference], + bulk_writer: Optional["BulkWriter"] = None, # type: ignore + ) -> int: + raise NotImplementedError + @staticmethod def field_path(*field_names: str) -> str: """Create a **field path** from a list of nested field names. diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index 32694ac472..9e15b108c2 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -315,10 +315,10 @@ def _prep_collections( def collections( self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, - ) -> NoReturn: + ) -> None: raise NotImplementedError - def on_snapshot(self, callback) -> NoReturn: + def on_snapshot(self, callback) -> None: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 1812cfca00..4f3ee101ff 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -424,6 +424,12 @@ def limit_to_last(self, count: int) -> "BaseQuery": """ return self._copy(limit=count, limit_to_last=True) + def _resolve_chunk_size(self, num_loaded: int, chunk_size: int) -> int: + """Utility function for chunkify.""" + if self._limit is not None and (num_loaded + chunk_size) > self._limit: + return max(self._limit - num_loaded, 0) + return chunk_size + def offset(self, num_to_skip: int) -> "BaseQuery": """Skip to an offset in a query. 
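`BaseQuery._resolve_chunk_size` above is what lets `_chunkify` honor a previously applied `limit()`: each page asks for at most `chunk_size` documents, and the final page is trimmed so the overall limit is never exceeded. A small standalone restatement of that arithmetic (not an import of the private helper) for a query limited to 10 documents and paged 4 at a time:

```py
from typing import Optional


def resolve_chunk_size(limit: Optional[int], num_loaded: int, chunk_size: int) -> int:
    """Mirror of the limit-trimming logic added to BaseQuery above."""
    if limit is not None and (num_loaded + chunk_size) > limit:
        return max(limit - num_loaded, 0)
    return chunk_size


# limit=10, chunk_size=4 -> pages of 4, 4, and finally 2 documents.
assert [resolve_chunk_size(10, loaded, 4) for loaded in (0, 4, 8)] == [4, 4, 2]
# No limit -> every page simply asks for chunk_size documents.
assert resolve_chunk_size(None, 0, 4) == 4
```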
diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 20ef5055f3..750acb0beb 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -39,17 +39,22 @@ from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.transaction import Transaction from google.cloud.firestore_v1.services.firestore import client as firestore_client from google.cloud.firestore_v1.services.firestore.transports import ( grpc as firestore_grpc_transport, ) -from typing import Any, Generator, Iterable +from typing import Any, Generator, Iterable, List, Optional, Union, TYPE_CHECKING # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot +if TYPE_CHECKING: + from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER + + class Client(BaseClient): """Client for interacting with Google Cloud Firestore API. @@ -286,6 +291,87 @@ def collections( for collection_id in iterator: yield self.collection(collection_id) + def recursive_delete( + self, + reference: Union[CollectionReference, DocumentReference], + *, + bulk_writer: Optional["BulkWriter"] = None, + chunk_size: Optional[int] = 5000, + ) -> int: + """Deletes documents and their subcollections, regardless of collection + name. + + Passing a CollectionReference leads to each document in the collection + getting deleted, as well as all of their descendents. + + Passing a DocumentReference deletes that one document and all of its + descendents. + + Args: + reference (Union[ + :class:`@google.cloud.firestore_v1.collection.CollectionReference`, + :class:`@google.cloud.firestore_v1.document.DocumentReference`, + ]) + The reference to be deleted. + + bulk_writer (Optional[:class:`@google.cloud.firestore_v1.bulk_writer.BulkWriter`]) + The BulkWriter used to delete all matching documents. Supply this + if you want to override the default throttling behavior. + + """ + return self._recursive_delete( + reference, bulk_writer=bulk_writer, chunk_size=chunk_size, + ) + + def _recursive_delete( + self, + reference: Union[CollectionReference, DocumentReference], + *, + bulk_writer: Optional["BulkWriter"] = None, + chunk_size: Optional[int] = 5000, + depth: Optional[int] = 0, + ) -> int: + """Recursion helper for `recursive_delete.""" + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + bulk_writer = bulk_writer or BulkWriter() + + num_deleted: int = 0 + + if isinstance(reference, CollectionReference): + chunk: List[DocumentSnapshot] + for chunk in ( + reference.recursive() + .select([FieldPath.document_id()]) + ._chunkify(chunk_size) + ): + doc_snap: DocumentSnapshot + for doc_snap in chunk: + num_deleted += 1 + bulk_writer.delete(doc_snap.reference) + + elif isinstance(reference, DocumentReference): + col_ref: CollectionReference + for col_ref in reference.collections(): + num_deleted += self._recursive_delete( + col_ref, + bulk_writer=bulk_writer, + chunk_size=chunk_size, + depth=depth + 1, + ) + num_deleted += 1 + bulk_writer.delete(reference) + + else: + raise TypeError( + f"Unexpected type for reference: {reference.__class__.__name__}" + ) + + if depth == 0: + bulk_writer.close() + + return num_deleted + def batch(self) -> WriteBatch: """Get a batch instance from this client. 
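The synchronous `Client.recursive_delete` above mirrors the async variant, and its optional `bulk_writer` argument lets callers override the default throttling, which is exactly what the system tests later in this patch do. A hedged sketch of that pattern; the collection name is illustrative and the client assumes default credentials:

```py
from google.cloud import firestore
from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode

client = firestore.Client()  # assumes default credentials/project

# Supplying a BulkWriter overrides the default throttling behaviour.
bw = client.bulk_writer(options=BulkWriterOptions(mode=SendMode.parallel))
num_deleted = client.recursive_delete(
    client.collection("philosophers"), bulk_writer=bw
)
print(f"recursively deleted {num_deleted} documents")
```

Passing a document reference instead of a collection deletes that document along with all of its subcollections, as described in the docstring above.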
diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 96d076e2c4..643e2d7ef1 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -137,6 +137,9 @@ def list_documents( ) return (_item_to_document_ref(self, i) for i in iterator) + def _chunkify(self, chunk_size: int): + return self._query()._chunkify(chunk_size) + def get( self, transaction: Transaction = None, diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index f1e044cbd1..50c5559b14 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -18,7 +18,6 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ - from google.cloud import firestore_v1 from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.api_core import gapic_v1 # type: ignore @@ -35,7 +34,7 @@ from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List, Type +from typing import Any, Callable, Generator, List, Optional, Type class Query(BaseQuery): @@ -168,6 +167,47 @@ def get( return list(result) + def _chunkify( + self, chunk_size: int + ) -> Generator[List[DocumentSnapshot], None, None]: + # Catch the edge case where a developer writes the following: + # `my_query.limit(500)._chunkify(1000)`, which ultimately nullifies any + # need to yield chunks. + if self._limit and chunk_size > self._limit: + yield self.get() + return + + max_to_return: Optional[int] = self._limit + num_returned: int = 0 + original: Query = self._copy() + last_document: Optional[DocumentSnapshot] = None + + while True: + # Optionally trim the `chunk_size` down to honor a previously + # applied limits as set by `self.limit()` + _chunk_size: int = original._resolve_chunk_size(num_returned, chunk_size) + + # Apply the optionally pruned limit and the cursor, if we are past + # the first page. + _q = original.limit(_chunk_size) + if last_document: + _q = _q.start_after(last_document) + + snapshots = _q.get() + last_document = snapshots[-1] + num_returned += len(snapshots) + + yield snapshots + + # Terminate the iterator if we have reached either of two end + # conditions: + # 1. There are no more documents, or + # 2. We have reached the desired overall limit + if len(snapshots) < _chunk_size or ( + max_to_return and num_returned >= max_to_return + ): + return + def stream( self, transaction=None, diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 0975a73d09..109029ced2 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -29,6 +29,7 @@ from google.cloud import firestore_v1 as firestore from time import sleep +from typing import Callable, Dict, List, Optional from tests.system.test__helpers import ( FIRESTORE_CREDS, @@ -1235,65 +1236,157 @@ def test_array_union(client, cleanup): assert doc_ref.get().to_dict() == expected -def test_recursive_query(client, cleanup): +def _persist_documents( + client: firestore.Client, + collection_name: str, + documents: List[Dict], + cleanup: Optional[Callable] = None, +): + """Assuming `documents` is a recursive list of dictionaries representing + documents and subcollections, this method writes all of those through + `client.collection(...).document(...).create()`. 
- philosophers = [ + `documents` must be of this structure: + ```py + documents = [ { - "data": {"name": "Socrates", "favoriteCity": "Athens"}, - "subcollections": { - "pets": [{"name": "Scruffy"}, {"name": "Snowflake"}], - "hobbies": [{"name": "pontificating"}, {"name": "journaling"}], - "philosophers": [{"name": "Aristotle"}, {"name": "Plato"}], - }, + # Required key + "data": , + + # Optional key + "subcollections": , }, - { - "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, - "subcollections": { - "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], - "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], - }, + ... + ] + ``` + """ + for block in documents: + col_ref = client.collection(collection_name) + document_id: str = block["data"]["name"] + doc_ref = col_ref.document(document_id) + doc_ref.set(block["data"]) + if cleanup is not None: + cleanup(doc_ref.delete) + + if "subcollections" in block: + for subcollection_name, inner_blocks in block["subcollections"].items(): + _persist_documents( + client, + f"{collection_name}/{document_id}/{subcollection_name}", + inner_blocks, + ) + + +# documents compatible with `_persist_documents` +philosophers_data_set = [ + { + "data": {"name": "Socrates", "favoriteCity": "Athens"}, + "subcollections": { + "pets": [{"data": {"name": "Scruffy"}}, {"data": {"name": "Snowflake"}}], + "hobbies": [ + {"data": {"name": "pontificating"}}, + {"data": {"name": "journaling"}}, + ], + "philosophers": [ + {"data": {"name": "Aristotle"}}, + {"data": {"name": "Plato"}}, + ], }, - { - "data": {"name": "Plato", "favoriteCity": "Corinth"}, - "subcollections": { - "pets": [{"name": "Cuddles"}, {"name": "Sergeant-Puppers"}], - "hobbies": [{"name": "abstraction"}, {"name": "hypotheticals"}], - }, + }, + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"data": {"name": "Floof-Boy"}}, {"data": {"name": "Doggy-Dog"}}], + "hobbies": [ + {"data": {"name": "questioning-stuff"}}, + {"data": {"name": "meditation"}}, + ], + }, + }, + { + "data": {"name": "Plato", "favoriteCity": "Corinth"}, + "subcollections": { + "pets": [ + {"data": {"name": "Cuddles"}}, + {"data": {"name": "Sergeant-Puppers"}}, + ], + "hobbies": [ + {"data": {"name": "abstraction"}}, + {"data": {"name": "hypotheticals"}}, + ], }, + }, +] + + +def _do_recursive_delete_with_bulk_writer(client, bulk_writer): + philosophers = [philosophers_data_set[0]] + _persist_documents(client, f"philosophers{UNIQUE_RESOURCE_ID}", philosophers) + + doc_paths = [ + "", + "/pets/Scruffy", + "/pets/Snowflake", + "/hobbies/pontificating", + "/hobbies/journaling", + "/philosophers/Aristotle", + "/philosophers/Plato", ] - db = client - collection_ref = db.collection("philosophers") - for philosopher in philosophers: - ref = collection_ref.document( - f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}" - ) - ref.set(philosopher["data"]) - cleanup(ref.delete) - for col_name, entries in philosopher["subcollections"].items(): - sub_col = ref.collection(col_name) - for entry in entries: - inner_doc_ref = sub_col.document(entry["name"]) - inner_doc_ref.set(entry) - cleanup(inner_doc_ref.delete) + # Assert all documents were created so that when they're missing after the + # delete, we're actually testing something. 
+ collection_ref = client.collection(f"philosophers{UNIQUE_RESOURCE_ID}") + for path in doc_paths: + snapshot = collection_ref.document(f"Socrates{path}").get() + assert snapshot.exists, f"Snapshot at Socrates{path} should have been created" + + # Now delete. + num_deleted = client.recursive_delete(collection_ref, bulk_writer=bulk_writer) + assert num_deleted == len(doc_paths) + + # Now they should all be missing + for path in doc_paths: + snapshot = collection_ref.document(f"Socrates{path}").get() + assert ( + not snapshot.exists + ), f"Snapshot at Socrates{path} should have been deleted" + + +def test_recursive_delete_parallelized(client, cleanup): + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode + + bw = client.bulk_writer(options=BulkWriterOptions(mode=SendMode.parallel)) + _do_recursive_delete_with_bulk_writer(client, bw) + - ids = [doc.id for doc in db.collection_group("philosophers").recursive().get()] +def test_recursive_delete_serialized(client, cleanup): + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode + + bw = client.bulk_writer(options=BulkWriterOptions(mode=SendMode.serial)) + _do_recursive_delete_with_bulk_writer(client, bw) + + +def test_recursive_query(client, cleanup): + col_id: str = f"philosophers-recursive-query{UNIQUE_RESOURCE_ID}" + _persist_documents(client, col_id, philosophers_data_set, cleanup) + + ids = [doc.id for doc in client.collection_group(col_id).recursive().get()] expected_ids = [ # Aristotle doc and subdocs - f"Aristotle{UNIQUE_RESOURCE_ID}", + "Aristotle", "meditation", "questioning-stuff", "Doggy-Dog", "Floof-Boy", # Plato doc and subdocs - f"Plato{UNIQUE_RESOURCE_ID}", + "Plato", "abstraction", "hypotheticals", "Cuddles", "Sergeant-Puppers", # Socrates doc and subdocs - f"Socrates{UNIQUE_RESOURCE_ID}", + "Socrates", "journaling", "pontificating", "Scruffy", @@ -1312,34 +1405,12 @@ def test_recursive_query(client, cleanup): def test_nested_recursive_query(client, cleanup): + col_id: str = f"philosophers-nested-recursive-query{UNIQUE_RESOURCE_ID}" + _persist_documents(client, col_id, philosophers_data_set, cleanup) - philosophers = [ - { - "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, - "subcollections": { - "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], - "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], - }, - }, - ] - - db = client - collection_ref = db.collection("philosophers") - for philosopher in philosophers: - ref = collection_ref.document( - f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}" - ) - ref.set(philosopher["data"]) - cleanup(ref.delete) - for col_name, entries in philosopher["subcollections"].items(): - sub_col = ref.collection(col_name) - for entry in entries: - inner_doc_ref = sub_col.document(entry["name"]) - inner_doc_ref.set(entry) - cleanup(inner_doc_ref.delete) - - aristotle = collection_ref.document(f"Aristotle{UNIQUE_RESOURCE_ID}") - ids = [doc.id for doc in aristotle.collection("pets")._query().recursive().get()] + collection_ref = client.collection(col_id) + aristotle = collection_ref.document("Aristotle") + ids = [doc.id for doc in aristotle.collection("pets").recursive().get()] expected_ids = [ # Aristotle pets @@ -1356,6 +1427,79 @@ def test_nested_recursive_query(client, cleanup): assert ids[index] == expected_ids[index], error_msg +def test_chunked_query(client, cleanup): + col = client.collection(f"chunked-test{UNIQUE_RESOURCE_ID}") + for index in range(10): + doc_ref = col.document(f"document-{index + 1}") + 
doc_ref.set({"index": index}) + cleanup(doc_ref.delete) + + iter = col._chunkify(3) + assert len(next(iter)) == 3 + assert len(next(iter)) == 3 + assert len(next(iter)) == 3 + assert len(next(iter)) == 1 + + +def test_chunked_query_smaller_limit(client, cleanup): + col = client.collection(f"chunked-test-smaller-limit{UNIQUE_RESOURCE_ID}") + for index in range(10): + doc_ref = col.document(f"document-{index + 1}") + doc_ref.set({"index": index}) + cleanup(doc_ref.delete) + + iter = col.limit(5)._chunkify(9) + assert len(next(iter)) == 5 + + +def test_chunked_and_recursive(client, cleanup): + col_id = f"chunked-recursive-test{UNIQUE_RESOURCE_ID}" + documents = [ + { + "data": {"name": "Root-1"}, + "subcollections": { + "children": [ + {"data": {"name": f"Root-1--Child-{index + 1}"}} + for index in range(5) + ] + }, + }, + { + "data": {"name": "Root-2"}, + "subcollections": { + "children": [ + {"data": {"name": f"Root-2--Child-{index + 1}"}} + for index in range(5) + ] + }, + }, + ] + _persist_documents(client, col_id, documents, cleanup) + collection_ref = client.collection(col_id) + iter = collection_ref.recursive()._chunkify(5) + + page_1_ids = [ + "Root-1", + "Root-1--Child-1", + "Root-1--Child-2", + "Root-1--Child-3", + "Root-1--Child-4", + ] + assert [doc.id for doc in next(iter)] == page_1_ids + + page_2_ids = [ + "Root-1--Child-5", + "Root-2", + "Root-2--Child-1", + "Root-2--Child-2", + "Root-2--Child-3", + ] + assert [doc.id for doc in next(iter)] == page_2_ids + + page_3_ids = ["Root-2--Child-4", "Root-2--Child-5"] + assert [doc.id for doc in next(iter)] == page_3_ids + + def test_watch_query_order(client, cleanup): db = client collection_ref = db.collection("users") diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index a4db4e75ff..b7c562fd3d 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -18,6 +18,7 @@ import math import pytest import operator +from typing import Callable, Dict, List, Optional from google.oauth2 import service_account @@ -1094,67 +1095,159 @@ async def test_batch(client, cleanup): assert not (await document3.get()).exists -async def test_recursive_query(client, cleanup): +async def _persist_documents( + client: firestore.AsyncClient, + collection_name: str, + documents: List[Dict], + cleanup: Optional[Callable] = None, +): + """Assuming `documents` is a recursive list of dictionaries representing + documents and subcollections, this method writes all of those through + `client.collection(...).document(...).create()`. - philosophers = [ + `documents` must be of this structure: + ```py + documents = [ { - "data": {"name": "Socrates", "favoriteCity": "Athens"}, - "subcollections": { - "pets": [{"name": "Scruffy"}, {"name": "Snowflake"}], - "hobbies": [{"name": "pontificating"}, {"name": "journaling"}], - "philosophers": [{"name": "Aristotle"}, {"name": "Plato"}], - }, + # Required key + "data": , + + # Optional key + "subcollections": , }, - { - "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, - "subcollections": { - "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], - "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], - }, + ... 
+ ] + ``` + """ + for block in documents: + col_ref = client.collection(collection_name) + document_id: str = block["data"]["name"] + doc_ref = col_ref.document(document_id) + await doc_ref.set(block["data"]) + if cleanup is not None: + cleanup(doc_ref.delete) + + if "subcollections" in block: + for subcollection_name, inner_blocks in block["subcollections"].items(): + await _persist_documents( + client, + f"{collection_name}/{document_id}/{subcollection_name}", + inner_blocks, + ) + + +# documents compatible with `_persist_documents` +philosophers_data_set = [ + { + "data": {"name": "Socrates", "favoriteCity": "Athens"}, + "subcollections": { + "pets": [{"data": {"name": "Scruffy"}}, {"data": {"name": "Snowflake"}}], + "hobbies": [ + {"data": {"name": "pontificating"}}, + {"data": {"name": "journaling"}}, + ], + "philosophers": [ + {"data": {"name": "Aristotle"}}, + {"data": {"name": "Plato"}}, + ], }, - { - "data": {"name": "Plato", "favoriteCity": "Corinth"}, - "subcollections": { - "pets": [{"name": "Cuddles"}, {"name": "Sergeant-Puppers"}], - "hobbies": [{"name": "abstraction"}, {"name": "hypotheticals"}], - }, + }, + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"data": {"name": "Floof-Boy"}}, {"data": {"name": "Doggy-Dog"}}], + "hobbies": [ + {"data": {"name": "questioning-stuff"}}, + {"data": {"name": "meditation"}}, + ], }, - ] + }, + { + "data": {"name": "Plato", "favoriteCity": "Corinth"}, + "subcollections": { + "pets": [ + {"data": {"name": "Cuddles"}}, + {"data": {"name": "Sergeant-Puppers"}}, + ], + "hobbies": [ + {"data": {"name": "abstraction"}}, + {"data": {"name": "hypotheticals"}}, + ], + }, + }, +] - db = client - collection_ref = db.collection("philosophers") - for philosopher in philosophers: - ref = collection_ref.document( - f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}-async" - ) - await ref.set(philosopher["data"]) - cleanup(ref.delete) - for col_name, entries in philosopher["subcollections"].items(): - sub_col = ref.collection(col_name) - for entry in entries: - inner_doc_ref = sub_col.document(entry["name"]) - await inner_doc_ref.set(entry) - cleanup(inner_doc_ref.delete) - - ids = [ - doc.id for doc in await db.collection_group("philosophers").recursive().get() + +async def _do_recursive_delete_with_bulk_writer(client, bulk_writer): + philosophers = [philosophers_data_set[0]] + await _persist_documents( + client, f"philosophers-async{UNIQUE_RESOURCE_ID}", philosophers + ) + + doc_paths = [ + "", + "/pets/Scruffy", + "/pets/Snowflake", + "/hobbies/pontificating", + "/hobbies/journaling", + "/philosophers/Aristotle", + "/philosophers/Plato", ] + # Assert all documents were created so that when they're missing after the + # delete, we're actually testing something. + collection_ref = client.collection(f"philosophers-async{UNIQUE_RESOURCE_ID}") + for path in doc_paths: + snapshot = await collection_ref.document(f"Socrates{path}").get() + assert snapshot.exists, f"Snapshot at Socrates{path} should have been created" + + # Now delete. 
+ num_deleted = await client.recursive_delete(collection_ref, bulk_writer=bulk_writer) + assert num_deleted == len(doc_paths) + + # Now they should all be missing + for path in doc_paths: + snapshot = await collection_ref.document(f"Socrates{path}").get() + assert ( + not snapshot.exists + ), f"Snapshot at Socrates{path} should have been deleted" + + +async def test_async_recursive_delete_parallelized(client, cleanup): + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode + + bw = client.bulk_writer(options=BulkWriterOptions(mode=SendMode.parallel)) + await _do_recursive_delete_with_bulk_writer(client, bw) + + +async def test_async_recursive_delete_serialized(client, cleanup): + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode + + bw = client.bulk_writer(options=BulkWriterOptions(mode=SendMode.serial)) + await _do_recursive_delete_with_bulk_writer(client, bw) + + +async def test_recursive_query(client, cleanup): + col_id: str = f"philosophers-recursive-async-query{UNIQUE_RESOURCE_ID}" + await _persist_documents(client, col_id, philosophers_data_set, cleanup) + + ids = [doc.id for doc in await client.collection_group(col_id).recursive().get()] + expected_ids = [ # Aristotle doc and subdocs - f"Aristotle{UNIQUE_RESOURCE_ID}-async", + "Aristotle", "meditation", "questioning-stuff", "Doggy-Dog", "Floof-Boy", # Plato doc and subdocs - f"Plato{UNIQUE_RESOURCE_ID}-async", + "Plato", "abstraction", "hypotheticals", "Cuddles", "Sergeant-Puppers", # Socrates doc and subdocs - f"Socrates{UNIQUE_RESOURCE_ID}-async", + "Socrates", "journaling", "pontificating", "Scruffy", @@ -1173,36 +1266,12 @@ async def test_recursive_query(client, cleanup): async def test_nested_recursive_query(client, cleanup): + col_id: str = f"philosophers-nested-recursive-async-query{UNIQUE_RESOURCE_ID}" + await _persist_documents(client, col_id, philosophers_data_set, cleanup) - philosophers = [ - { - "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, - "subcollections": { - "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], - "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], - }, - }, - ] - - db = client - collection_ref = db.collection("philosophers") - for philosopher in philosophers: - ref = collection_ref.document( - f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}-async" - ) - await ref.set(philosopher["data"]) - cleanup(ref.delete) - for col_name, entries in philosopher["subcollections"].items(): - sub_col = ref.collection(col_name) - for entry in entries: - inner_doc_ref = sub_col.document(entry["name"]) - await inner_doc_ref.set(entry) - cleanup(inner_doc_ref.delete) - - aristotle = collection_ref.document(f"Aristotle{UNIQUE_RESOURCE_ID}-async") - ids = [ - doc.id for doc in await aristotle.collection("pets")._query().recursive().get() - ] + collection_ref = client.collection(col_id) + aristotle = collection_ref.document("Aristotle") + ids = [doc.id for doc in await aristotle.collection("pets").recursive().get()] expected_ids = [ # Aristotle pets @@ -1219,6 +1288,84 @@ async def test_nested_recursive_query(client, cleanup): assert ids[index] == expected_ids[index], error_msg +async def test_chunked_query(client, cleanup): + col = client.collection(f"async-chunked-test{UNIQUE_RESOURCE_ID}") + for index in range(10): + doc_ref = col.document(f"document-{index + 1}") + await doc_ref.set({"index": index}) + cleanup(doc_ref.delete) + + lengths: List[int] = [len(chunk) async for chunk in col._chunkify(3)] + assert len(lengths) == 4 + 
assert lengths[0] == 3 + assert lengths[1] == 3 + assert lengths[2] == 3 + assert lengths[3] == 1 + + +async def test_chunked_query_smaller_limit(client, cleanup): + col = client.collection(f"chunked-test-smaller-limit{UNIQUE_RESOURCE_ID}") + for index in range(10): + doc_ref = col.document(f"document-{index + 1}") + await doc_ref.set({"index": index}) + cleanup(doc_ref.delete) + + lengths: List[int] = [len(chunk) async for chunk in col.limit(5)._chunkify(9)] + assert len(lengths) == 1 + assert lengths[0] == 5 + + +async def test_chunked_and_recursive(client, cleanup): + col_id = f"chunked-async-recursive-test{UNIQUE_RESOURCE_ID}" + documents = [ + { + "data": {"name": "Root-1"}, + "subcollections": { + "children": [ + {"data": {"name": f"Root-1--Child-{index + 1}"}} + for index in range(5) + ] + }, + }, + { + "data": {"name": "Root-2"}, + "subcollections": { + "children": [ + {"data": {"name": f"Root-2--Child-{index + 1}"}} + for index in range(5) + ] + }, + }, + ] + await _persist_documents(client, col_id, documents, cleanup) + collection_ref = client.collection(col_id) + iter = collection_ref.recursive()._chunkify(5) + + pages = [page async for page in iter] + doc_ids = [[doc.id for doc in page] for page in pages] + + page_1_ids = [ + "Root-1", + "Root-1--Child-1", + "Root-1--Child-2", + "Root-1--Child-3", + "Root-1--Child-4", + ] + assert doc_ids[0] == page_1_ids + + page_2_ids = [ + "Root-1--Child-5", + "Root-2", + "Root-2--Child-1", + "Root-2--Child-2", + "Root-2--Child-3", + ] + assert doc_ids[1] == page_2_ids + + page_3_ids = ["Root-2--Child-4", "Root-2--Child-5"] + assert doc_ids[2] == page_3_ids + + async def _chain(*iterators): """Asynchronous reimplementation of `itertools.chain`.""" for iterator in iterators: diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index bb7a51dd83..598da81eab 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -18,6 +18,8 @@ import aiounittest import mock +from google.cloud.firestore_v1.types.document import Document +from google.cloud.firestore_v1.types.firestore import RunQueryResponse from tests.unit.v1.test__helpers import AsyncIter, AsyncMock @@ -388,6 +390,110 @@ def test_sync_copy(self): # Multiple calls to this method should return the same cached instance. 
self.assertIs(client._to_sync_copy(), client._to_sync_copy()) + @pytest.mark.asyncio + async def test_recursive_delete(self): + client = self._make_default_one() + client._firestore_api_internal = AsyncMock(spec=["run_query"]) + collection_ref = client.collection("my_collection") + + results = [] + for index in range(10): + results.append( + RunQueryResponse(document=Document(name=f"{collection_ref.id}/{index}")) + ) + + chunks = [ + results[:3], + results[3:6], + results[6:9], + results[9:], + ] + + def _get_chunk(*args, **kwargs): + return AsyncIter(items=chunks.pop(0)) + + client._firestore_api_internal.run_query.side_effect = _get_chunk + + bulk_writer = mock.MagicMock() + bulk_writer.mock_add_spec(spec=["delete", "close"]) + + num_deleted = await client.recursive_delete( + collection_ref, bulk_writer=bulk_writer, chunk_size=3 + ) + self.assertEqual(num_deleted, len(results)) + + @pytest.mark.asyncio + async def test_recursive_delete_from_document(self): + client = self._make_default_one() + client._firestore_api_internal = mock.Mock( + spec=["run_query", "list_collection_ids"] + ) + collection_ref = client.collection("my_collection") + + collection_1_id: str = "collection_1_id" + collection_2_id: str = "collection_2_id" + + parent_doc = collection_ref.document("parent") + + collection_1_results = [] + collection_2_results = [] + + for index in range(10): + collection_1_results.append( + RunQueryResponse(document=Document(name=f"{collection_1_id}/{index}"),), + ) + + collection_2_results.append( + RunQueryResponse(document=Document(name=f"{collection_2_id}/{index}"),), + ) + + col_1_chunks = [ + collection_1_results[:3], + collection_1_results[3:6], + collection_1_results[6:9], + collection_1_results[9:], + ] + + col_2_chunks = [ + collection_2_results[:3], + collection_2_results[3:6], + collection_2_results[6:9], + collection_2_results[9:], + ] + + async def _get_chunk(*args, **kwargs): + start_at = ( + kwargs["request"]["structured_query"].start_at.values[0].reference_value + ) + + if collection_1_id in start_at: + return AsyncIter(col_1_chunks.pop(0)) + return AsyncIter(col_2_chunks.pop(0)) + + async def _get_collections(*args, **kwargs): + return AsyncIter([collection_1_id, collection_2_id]) + + client._firestore_api_internal.run_query.side_effect = _get_chunk + client._firestore_api_internal.list_collection_ids.side_effect = ( + _get_collections + ) + + bulk_writer = mock.MagicMock() + bulk_writer.mock_add_spec(spec=["delete", "close"]) + + num_deleted = await client.recursive_delete( + parent_doc, bulk_writer=bulk_writer, chunk_size=3 + ) + + expected_len = len(collection_1_results) + len(collection_2_results) + 1 + self.assertEqual(num_deleted, expected_len) + + @pytest.mark.asyncio + async def test_recursive_delete_raises(self): + client = self._make_default_one() + with self.assertRaises(TypeError): + await client.recursive_delete(object()) + def test_batch(self): from google.cloud.firestore_v1.async_batch import AsyncWriteBatch diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 33006e2542..1955ca52de 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from google.cloud.firestore_v1.types.document import Document +from google.cloud.firestore_v1.types.firestore import RunQueryResponse import pytest import types import aiounittest @@ -204,6 +206,42 @@ async def test_add_w_retry_timeout(self): timeout = 123.0 await self._add_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio + async def test_chunkify(self): + client = _make_client() + col = client.collection("my-collection") + + client._firestore_api_internal = mock.Mock(spec=["run_query"]) + + results = [] + for index in range(10): + results.append( + RunQueryResponse( + document=Document( + name=f"projects/project-project/databases/(default)/documents/my-collection/{index}", + ), + ), + ) + + chunks = [ + results[:3], + results[3:6], + results[6:9], + results[9:], + ] + + async def _get_chunk(*args, **kwargs): + return AsyncIter(chunks.pop(0)) + + client._firestore_api_internal.run_query.side_effect = _get_chunk + + counter = 0 + expected_lengths = [3, 3, 3, 1] + async for chunk in col._chunkify(3): + msg = f"Expected chunk of length {expected_lengths[counter]} at index {counter}. Saw {len(chunk)}." + self.assertEqual(len(chunk), expected_lengths[counter], msg) + counter += 1 + @pytest.mark.asyncio async def _list_documents_helper(self, page_size=None, retry=None, timeout=None): from google.cloud.firestore_v1 import _helpers diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 64feddaf4e..4d18d551b3 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud.firestore_v1.types.document import Document +from google.cloud.firestore_v1.types.firestore import RunQueryResponse import pytest import types import aiounittest @@ -469,6 +471,28 @@ async def test_stream_w_collection_group(self): metadata=client._rpc_metadata, ) + @pytest.mark.asyncio + async def test_unnecessary_chunkify(self): + client = _make_client() + + firestore_api = AsyncMock(spec=["run_query"]) + firestore_api.run_query.return_value = AsyncIter( + [ + RunQueryResponse( + document=Document( + name=f"projects/project-project/databases/(default)/documents/asdf/{index}", + ), + ) + for index in range(5) + ] + ) + client._firestore_api_internal = firestore_api + + query = client.collection("asdf")._query() + + async for chunk in query.limit(5)._chunkify(10): + self.assertEqual(len(chunk), 5) + class TestCollectionGroup(aiounittest.AsyncTestCase): @staticmethod diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index a46839ac59..5fbc73793e 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -17,6 +17,8 @@ import unittest import mock +from google.cloud.firestore_v1.types.document import Document +from google.cloud.firestore_v1.types.firestore import RunQueryResponse class TestClient(unittest.TestCase): @@ -360,6 +362,104 @@ def test_get_all_unknown_result(self): metadata=client._rpc_metadata, ) + def test_recursive_delete(self): + client = self._make_default_one() + client._firestore_api_internal = mock.Mock(spec=["run_query"]) + collection_ref = client.collection("my_collection") + + results = [] + for index in range(10): + results.append( + RunQueryResponse(document=Document(name=f"{collection_ref.id}/{index}")) + ) + + chunks = [ + results[:3], + results[3:6], + results[6:9], + results[9:], + ] + + def _get_chunk(*args, **kwargs): + return iter(chunks.pop(0)) + + 
client._firestore_api_internal.run_query.side_effect = _get_chunk + + bulk_writer = mock.MagicMock() + bulk_writer.mock_add_spec(spec=["delete", "close"]) + + num_deleted = client.recursive_delete( + collection_ref, bulk_writer=bulk_writer, chunk_size=3 + ) + self.assertEqual(num_deleted, len(results)) + + def test_recursive_delete_from_document(self): + client = self._make_default_one() + client._firestore_api_internal = mock.Mock( + spec=["run_query", "list_collection_ids"] + ) + collection_ref = client.collection("my_collection") + + collection_1_id: str = "collection_1_id" + collection_2_id: str = "collection_2_id" + + parent_doc = collection_ref.document("parent") + + collection_1_results = [] + collection_2_results = [] + + for index in range(10): + collection_1_results.append( + RunQueryResponse(document=Document(name=f"{collection_1_id}/{index}"),), + ) + + collection_2_results.append( + RunQueryResponse(document=Document(name=f"{collection_2_id}/{index}"),), + ) + + col_1_chunks = [ + collection_1_results[:3], + collection_1_results[3:6], + collection_1_results[6:9], + collection_1_results[9:], + ] + + col_2_chunks = [ + collection_2_results[:3], + collection_2_results[3:6], + collection_2_results[6:9], + collection_2_results[9:], + ] + + def _get_chunk(*args, **kwargs): + start_at = ( + kwargs["request"]["structured_query"].start_at.values[0].reference_value + ) + + if collection_1_id in start_at: + return iter(col_1_chunks.pop(0)) + return iter(col_2_chunks.pop(0)) + + client._firestore_api_internal.run_query.side_effect = _get_chunk + client._firestore_api_internal.list_collection_ids.return_value = [ + collection_1_id, + collection_2_id, + ] + + bulk_writer = mock.MagicMock() + bulk_writer.mock_add_spec(spec=["delete", "close"]) + + num_deleted = client.recursive_delete( + parent_doc, bulk_writer=bulk_writer, chunk_size=3 + ) + + expected_len = len(collection_1_results) + len(collection_2_results) + 1 + self.assertEqual(num_deleted, expected_len) + + def test_recursive_delete_raises(self): + client = self._make_default_one() + self.assertRaises(TypeError, client.recursive_delete, object()) + def test_batch(self): from google.cloud.firestore_v1.batch import WriteBatch diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 5885a29d97..cfefeb9e61 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from google.cloud.firestore_v1.types.document import Document +from google.cloud.firestore_v1.types.firestore import RunQueryResponse import types import unittest @@ -355,3 +357,38 @@ def test_recursive(self): col = self._make_one("collection") self.assertIsInstance(col.recursive(), Query) + + def test_chunkify(self): + client = _test_helpers.make_client() + col = client.collection("my-collection") + + client._firestore_api_internal = mock.Mock(spec=["run_query"]) + + results = [] + for index in range(10): + results.append( + RunQueryResponse( + document=Document( + name=f"projects/project-project/databases/(default)/documents/my-collection/{index}", + ), + ), + ) + + chunks = [ + results[:3], + results[3:6], + results[6:9], + results[9:], + ] + + def _get_chunk(*args, **kwargs): + return iter(chunks.pop(0)) + + client._firestore_api_internal.run_query.side_effect = _get_chunk + + counter = 0 + expected_lengths = [3, 3, 3, 1] + for chunk in col._chunkify(3): + msg = f"Expected chunk of length {expected_lengths[counter]} at index {counter}. Saw {len(chunk)}." + self.assertEqual(len(chunk), expected_lengths[counter], msg) + counter += 1 diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index 91172b120b..ea28969a84 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud.firestore_v1.types.document import Document +from google.cloud.firestore_v1.types.firestore import RunQueryResponse import types import unittest @@ -460,6 +462,40 @@ def test_on_snapshot(self, watch): query.on_snapshot(None) watch.for_query.assert_called_once() + def test_unnecessary_chunkify(self): + client = _make_client() + + firestore_api = mock.Mock(spec=["run_query"]) + firestore_api.run_query.return_value = iter( + [ + RunQueryResponse( + document=Document( + name=f"projects/project-project/databases/(default)/documents/asdf/{index}", + ), + ) + for index in range(5) + ] + ) + client._firestore_api_internal = firestore_api + + query = client.collection("asdf")._query() + + for chunk in query.limit(5)._chunkify(10): + self.assertEqual(len(chunk), 5) + + def test__resolve_chunk_size(self): + # With a global limit + query = _make_client().collection("asdf").limit(5) + self.assertEqual(query._resolve_chunk_size(3, 10), 2) + self.assertEqual(query._resolve_chunk_size(3, 1), 1) + self.assertEqual(query._resolve_chunk_size(3, 2), 2) + + # With no limit + query = _make_client().collection("asdf")._query() + self.assertEqual(query._resolve_chunk_size(3, 10), 10) + self.assertEqual(query._resolve_chunk_size(3, 1), 1) + self.assertEqual(query._resolve_chunk_size(3, 2), 2) + class TestCollectionGroup(unittest.TestCase): @staticmethod From d7ea274e6886ad70926177d3ee8f59ef9c4a7bdc Mon Sep 17 00:00:00 2001 From: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com> Date: Wed, 18 Aug 2021 07:40:38 -0600 Subject: [PATCH 16/19] chore: generate python samples templates in owlbot.py (#427) Generate python samples templates in owlbot.py --- owlbot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/owlbot.py b/owlbot.py index 10f5894422..f6cc418d63 100644 --- a/owlbot.py +++ b/owlbot.py @@ -138,6 +138,7 @@ def update_fixup_scripts(library): cov_level=100, split_system_tests=True, ) +python.py_samples(skip_readmes=True) s.move(templated_files) From b56fd2803c7d96832112ca6d9f4810dfd862bdb9 Mon Sep 17 00:00:00 2001 From: Anthonios Partheniou Date: Wed, 18 Aug 
2021 11:38:57 -0400 Subject: [PATCH 17/19] chore: add missing import in owlbot.py (#428) --- owlbot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/owlbot.py b/owlbot.py index f6cc418d63..2415ba8b07 100644 --- a/owlbot.py +++ b/owlbot.py @@ -18,6 +18,7 @@ import synthtool as s from synthtool import gcp +from synthtool.languages import python common = gcp.CommonTemplates() From cf9cddbfeb579ee7125532b09b7b785a6d131e5c Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 18 Aug 2021 12:12:09 -0400 Subject: [PATCH 18/19] tests: revert testing against prerelease deps on Python 3.9 (#426) Reverts googleapis/python-firestore#415 Consensus from today's meeting is that testing prereleases of third-party dependencies needs to happen outside the normal `presubmit` path. --- testing/constraints-3.9.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index 6d34489a53..e69de29bb2 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -1,2 +0,0 @@ -# Allow prerelease requirements ---pre From 5950e2bb5fe975a35c3bc4d613aa0fddd247694e Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 15:41:02 -0400 Subject: [PATCH 19/19] chore: release 2.3.0 (#418) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> Co-authored-by: Tres Seaver --- CHANGELOG.md | 22 ++++++++++++++++++++++ setup.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2076e7e9df..4500c6c1e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,28 @@ [1]: https://pypi.org/project/google-cloud-firestore/#history +## [2.3.0](https://www.github.com/googleapis/python-firestore/compare/v2.2.0...v2.3.0) (2021-08-18) + + +### Features + +* add bulk writer ([#396](https://www.github.com/googleapis/python-firestore/issues/396)) ([98a7753](https://www.github.com/googleapis/python-firestore/commit/98a7753f05240a2a75b9ffd42b7a148c65a6e87f)) +* add recursive delete ([#420](https://www.github.com/googleapis/python-firestore/issues/420)) ([813a57b](https://www.github.com/googleapis/python-firestore/commit/813a57b1070a1f6ac41d02897fab33f8039b83f9)) +* add support for recursive queries ([#407](https://www.github.com/googleapis/python-firestore/issues/407)) ([eb45a36](https://www.github.com/googleapis/python-firestore/commit/eb45a36e6c06b642106e061a32bfc119eb7e5bf0)) + + +### Bug Fixes + +* enable self signed jwt for grpc ([#405](https://www.github.com/googleapis/python-firestore/issues/405)) ([8703b48](https://www.github.com/googleapis/python-firestore/commit/8703b48c45e7bb742a794cad9597740c44182f81)) +* use insecure grpc channels with emulator ([#402](https://www.github.com/googleapis/python-firestore/issues/402)) ([4381ad5](https://www.github.com/googleapis/python-firestore/commit/4381ad503ca3e83510b876281fc768d00d40d499)) +* remove unused requirement pytz ([#422](https://www.github.com/googleapis/python-firestore/issues/422)) ([539c1d7](https://www.github.com/googleapis/python-firestore/commit/539c1d719191eb0ae3a49290c26b628de7c27cd5)) + + +### Documentation + +* added generated docs for Bundles ([#416](https://www.github.com/googleapis/python-firestore/issues/416)) ([0176cc7](https://www.github.com/googleapis/python-firestore/commit/0176cc7fef8752433b5c2496046d3a56557eb824)) +* fixed broken links to devsite ([#417](https://www.github.com/googleapis/python-firestore/issues/417)) 
([1adfc81](https://www.github.com/googleapis/python-firestore/commit/1adfc81237c4ddee665e81f1beaef808cddb860e)) + ## [2.2.0](https://www.github.com/googleapis/python-firestore/compare/v2.1.3...v2.2.0) (2021-07-22) diff --git a/setup.py b/setup.py index 50de98e267..d183dd391a 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ name = "google-cloud-firestore" description = "Google Cloud Firestore API client library" -version = "2.2.0" +version = "2.3.0" release_status = "Development Status :: 5 - Production/Stable" dependencies = [ # NOTE: Maintainers, please do not require google-api-core>=2.x.x
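
The system and unit tests in the patch above exercise the recursive-query, chunked-query, and recursive-delete surface that ships in 2.3.0. As a rough usage sketch only — not part of the patch series — the synchronous API those tests cover might be driven as below. The collection name `philosophers-demo` is a placeholder, and `_chunkify` is a private helper included purely to illustrate the page sizes the tests assert on.

```py
from google.cloud import firestore
from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode

client = firestore.Client()
col = client.collection("philosophers-demo")  # placeholder collection name

# Recursive query: returns the documents in `col` plus every document in its
# nested subcollections, as exercised by test_recursive_query above.
descendant_ids = [doc.id for doc in col.recursive().get()]
print(descendant_ids)

# Chunked iteration (private helper): yields lists of at most `chunk_size`
# snapshots -- e.g. 3, 3, 3, 1 for ten documents, per test_chunked_query.
for chunk in col._chunkify(3):
    print(len(chunk))

# Recursive delete: removes a collection (or document) and everything nested
# beneath it, optionally through a caller-supplied BulkWriter, and returns
# the number of documents deleted (see the test_recursive_delete_* tests).
bw = client.bulk_writer(options=BulkWriterOptions(mode=SendMode.serial))
num_deleted = client.recursive_delete(col, bulk_writer=bw)
print(num_deleted)
```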