Skip to content
This repository was archived by the owner on Jul 6, 2023. It is now read-only.

Commit f51a651

Browse files
feat: add context manager support in client (#48)
- [ ] Regenerate this pull request now. chore: fix docstring for first attribute of protos committer: @busunkim96 PiperOrigin-RevId: 401271153 Source-Link: googleapis/googleapis@787f8c9 Source-Link: https://github.com/googleapis/googleapis-gen/commit/81decffe9fc72396a8153e756d1d67a6eecfd620 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiODFkZWNmZmU5ZmM3MjM5NmE4MTUzZTc1NmQxZDY3YTZlZWNmZDYyMCJ9
1 parent 048892b commit f51a651

File tree

7 files changed

+100
-4
lines changed

7 files changed

+100
-4
lines changed

google/cloud/tpu_v1/services/tpu/async_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,12 @@ async def get_accelerator_type(
952952
# Done; return the response.
953953
return response
954954

955+
async def __aenter__(self):
956+
return self
957+
958+
async def __aexit__(self, exc_type, exc, tb):
959+
await self.transport.close()
960+
955961

956962
try:
957963
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(

google/cloud/tpu_v1/services/tpu/client.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,7 @@ def __init__(
383383
client_cert_source_for_mtls=client_cert_source_func,
384384
quota_project_id=client_options.quota_project_id,
385385
client_info=client_info,
386-
always_use_jwt_access=(
387-
Transport == type(self).get_transport_class("grpc")
388-
or Transport == type(self).get_transport_class("grpc_asyncio")
389-
),
386+
always_use_jwt_access=True,
390387
)
391388

392389
def list_nodes(
@@ -1184,6 +1181,19 @@ def get_accelerator_type(
11841181
# Done; return the response.
11851182
return response
11861183

1184+
def __enter__(self):
1185+
return self
1186+
1187+
def __exit__(self, type, value, traceback):
1188+
"""Releases underlying transport's resources.
1189+
1190+
.. warning::
1191+
ONLY use as a context manager if the transport is NOT shared
1192+
with other clients! Exiting the with block will CLOSE the transport
1193+
and may cause errors in other clients!
1194+
"""
1195+
self.transport.close()
1196+
11871197

11881198
try:
11891199
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(

google/cloud/tpu_v1/services/tpu/transports/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,15 @@ def _prep_wrapped_messages(self, client_info):
197197
),
198198
}
199199

200+
def close(self):
201+
"""Closes resources associated with the transport.
202+
203+
.. warning::
204+
Only call this method if the transport is NOT shared
205+
with other clients - this may cause errors in other clients!
206+
"""
207+
raise NotImplementedError()
208+
200209
@property
201210
def operations_client(self) -> operations_v1.OperationsClient:
202211
"""Return the client designed to process long-running operations."""

google/cloud/tpu_v1/services/tpu/transports/grpc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,5 +532,8 @@ def get_accelerator_type(
532532
)
533533
return self._stubs["get_accelerator_type"]
534534

535+
def close(self):
536+
self.grpc_channel.close()
537+
535538

536539
__all__ = ("TpuGrpcTransport",)

google/cloud/tpu_v1/services/tpu/transports/grpc_asyncio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,5 +544,8 @@ def get_accelerator_type(
544544
)
545545
return self._stubs["get_accelerator_type"]
546546

547+
def close(self):
548+
return self.grpc_channel.close()
549+
547550

548551
__all__ = ("TpuGrpcAsyncIOTransport",)

google/cloud/tpu_v1/types/cloud_tpu.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
class SchedulingConfig(proto.Message):
5050
r"""Sets the scheduling options for this node.
51+
5152
Attributes:
5253
preemptible (bool):
5354
Defines whether the node is preemptible.
@@ -62,6 +63,7 @@ class SchedulingConfig(proto.Message):
6263

6364
class NetworkEndpoint(proto.Message):
6465
r"""A network endpoint over which a TPU worker can be reached.
66+
6567
Attributes:
6668
ip_address (str):
6769
The IP address of this network endpoint.
@@ -75,6 +77,7 @@ class NetworkEndpoint(proto.Message):
7577

7678
class Node(proto.Message):
7779
r"""A TPU instance.
80+
7881
Attributes:
7982
name (str):
8083
Output only. Immutable. The name of the TPU
@@ -225,6 +228,7 @@ class ApiVersion(proto.Enum):
225228

226229
class ListNodesRequest(proto.Message):
227230
r"""Request for [ListNodes][google.cloud.tpu.v1.Tpu.ListNodes].
231+
228232
Attributes:
229233
parent (str):
230234
Required. The parent resource name.
@@ -242,6 +246,7 @@ class ListNodesRequest(proto.Message):
242246

243247
class ListNodesResponse(proto.Message):
244248
r"""Response for [ListNodes][google.cloud.tpu.v1.Tpu.ListNodes].
249+
245250
Attributes:
246251
nodes (Sequence[google.cloud.tpu_v1.types.Node]):
247252
The listed nodes.
@@ -262,6 +267,7 @@ def raw_page(self):
262267

263268
class GetNodeRequest(proto.Message):
264269
r"""Request for [GetNode][google.cloud.tpu.v1.Tpu.GetNode].
270+
265271
Attributes:
266272
name (str):
267273
Required. The resource name.
@@ -272,6 +278,7 @@ class GetNodeRequest(proto.Message):
272278

273279
class CreateNodeRequest(proto.Message):
274280
r"""Request for [CreateNode][google.cloud.tpu.v1.Tpu.CreateNode].
281+
275282
Attributes:
276283
parent (str):
277284
Required. The parent resource name.
@@ -288,6 +295,7 @@ class CreateNodeRequest(proto.Message):
288295

289296
class DeleteNodeRequest(proto.Message):
290297
r"""Request for [DeleteNode][google.cloud.tpu.v1.Tpu.DeleteNode].
298+
291299
Attributes:
292300
name (str):
293301
Required. The resource name.
@@ -298,6 +306,7 @@ class DeleteNodeRequest(proto.Message):
298306

299307
class ReimageNodeRequest(proto.Message):
300308
r"""Request for [ReimageNode][google.cloud.tpu.v1.Tpu.ReimageNode].
309+
301310
Attributes:
302311
name (str):
303312
The resource name.
@@ -311,6 +320,7 @@ class ReimageNodeRequest(proto.Message):
311320

312321
class StopNodeRequest(proto.Message):
313322
r"""Request for [StopNode][google.cloud.tpu.v1.Tpu.StopNode].
323+
314324
Attributes:
315325
name (str):
316326
The resource name.
@@ -321,6 +331,7 @@ class StopNodeRequest(proto.Message):
321331

322332
class StartNodeRequest(proto.Message):
323333
r"""Request for [StartNode][google.cloud.tpu.v1.Tpu.StartNode].
334+
324335
Attributes:
325336
name (str):
326337
The resource name.
@@ -331,6 +342,7 @@ class StartNodeRequest(proto.Message):
331342

332343
class TensorFlowVersion(proto.Message):
333344
r"""A tensorflow version that a Node can be configured with.
345+
334346
Attributes:
335347
name (str):
336348
The resource name.
@@ -405,6 +417,7 @@ def raw_page(self):
405417

406418
class AcceleratorType(proto.Message):
407419
r"""A accelerator type that a Node can be configured with.
420+
408421
Attributes:
409422
name (str):
410423
The resource name.
@@ -479,6 +492,7 @@ def raw_page(self):
479492

480493
class OperationMetadata(proto.Message):
481494
r"""Metadata describing an [Operation][google.longrunning.Operation]
495+
482496
Attributes:
483497
create_time (google.protobuf.timestamp_pb2.Timestamp):
484498
The time the operation was created.
@@ -510,6 +524,7 @@ class OperationMetadata(proto.Message):
510524

511525
class Symptom(proto.Message):
512526
r"""A Symptom instance.
527+
513528
Attributes:
514529
create_time (google.protobuf.timestamp_pb2.Timestamp):
515530
Timestamp when the Symptom is created.

tests/unit/gapic/tpu_v1/test_tpu.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from google.api_core import grpc_helpers_async
3333
from google.api_core import operation_async # type: ignore
3434
from google.api_core import operations_v1
35+
from google.api_core import path_template
3536
from google.auth import credentials as ga_credentials
3637
from google.auth.exceptions import MutualTLSChannelError
3738
from google.cloud.tpu_v1.services.tpu import TpuAsyncClient
@@ -3100,6 +3101,9 @@ def test_tpu_base_transport():
31003101
with pytest.raises(NotImplementedError):
31013102
getattr(transport, method)(request=object())
31023103

3104+
with pytest.raises(NotImplementedError):
3105+
transport.close()
3106+
31033107
# Additionally, the LRO client (a property) should
31043108
# also raise NotImplementedError
31053109
with pytest.raises(NotImplementedError):
@@ -3633,3 +3637,49 @@ def test_client_withDEFAULT_CLIENT_INFO():
36333637
credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
36343638
)
36353639
prep.assert_called_once_with(client_info)
3640+
3641+
3642+
@pytest.mark.asyncio
3643+
async def test_transport_close_async():
3644+
client = TpuAsyncClient(
3645+
credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
3646+
)
3647+
with mock.patch.object(
3648+
type(getattr(client.transport, "grpc_channel")), "close"
3649+
) as close:
3650+
async with client:
3651+
close.assert_not_called()
3652+
close.assert_called_once()
3653+
3654+
3655+
def test_transport_close():
3656+
transports = {
3657+
"grpc": "_grpc_channel",
3658+
}
3659+
3660+
for transport, close_name in transports.items():
3661+
client = TpuClient(
3662+
credentials=ga_credentials.AnonymousCredentials(), transport=transport
3663+
)
3664+
with mock.patch.object(
3665+
type(getattr(client.transport, close_name)), "close"
3666+
) as close:
3667+
with client:
3668+
close.assert_not_called()
3669+
close.assert_called_once()
3670+
3671+
3672+
def test_client_ctx():
3673+
transports = [
3674+
"grpc",
3675+
]
3676+
for transport in transports:
3677+
client = TpuClient(
3678+
credentials=ga_credentials.AnonymousCredentials(), transport=transport
3679+
)
3680+
# Test client calls underlying transport.
3681+
with mock.patch.object(type(client.transport), "close") as close:
3682+
close.assert_not_called()
3683+
with client:
3684+
pass
3685+
close.assert_called()

0 commit comments

Comments
 (0)