Skip to content

Commit 8321826

Browse files
yeesiancopybara-github
authored andcommitted
feat: GenAI SDK client - Add support for context specs when creating agent engine instances
PiperOrigin-RevId: 783344749
1 parent df2390e commit 8321826

File tree

4 files changed

+170
-74
lines changed

4 files changed

+170
-74
lines changed

tests/unit/vertexai/genai/replays/test_create_agent_engine.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,34 @@ def test_create_config_lightweight(client):
3838
}
3939

4040

41+
def test_create_with_context_spec(client):
42+
project = "test-project"
43+
location = "us-central1"
44+
parent = f"projects/{project}/locations/{location}"
45+
generation_model = f"{parent}/publishers/google/models/gemini-2.0-flash-001"
46+
embedding_model = f"{parent}/publishers/google/models/text-embedding-005"
47+
48+
agent_engine = client.agent_engines.create(
49+
config={
50+
"context_spec": {
51+
"memory_bank_config": {
52+
"generation_config": {"model": generation_model},
53+
"similarity_search_config": {
54+
"embedding_model": embedding_model,
55+
},
56+
},
57+
},
58+
"http_options": {"api_version": "v1beta1"},
59+
},
60+
)
61+
agent_engine = client.agent_engines.get(name=agent_engine.api_resource.name)
62+
memory_bank_config = agent_engine.api_resource.context_spec.memory_bank_config
63+
assert memory_bank_config.generation_config.model == generation_model
64+
assert (
65+
memory_bank_config.similarity_search_config.embedding_model == embedding_model
66+
)
67+
68+
4169
pytestmark = pytest_helper.setup(
4270
file=__file__,
4371
globals_for_file=globals(),

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,7 @@ def test_create_agent_engine_with_env_vars_dict(
11101110
gcs_dir_name=None,
11111111
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
11121112
env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT,
1113+
context_spec=None,
11131114
)
11141115
request_mock.assert_called_with(
11151116
"post",

vertexai/_genai/agent_engines.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,21 @@ def _ReasoningEngineSpec_to_vertex(
6363
return to_object
6464

6565

66+
def _ReasoningEngineContextSpec_to_vertex(
67+
from_object: Union[dict[str, Any], object],
68+
parent_object: Optional[dict[str, Any]] = None,
69+
) -> dict[str, Any]:
70+
to_object: dict[str, Any] = {}
71+
if getv(from_object, ["memory_bank_config"]) is not None:
72+
setv(
73+
to_object,
74+
["memoryBankConfig"],
75+
getv(from_object, ["memory_bank_config"]),
76+
)
77+
78+
return to_object
79+
80+
6681
def _CreateAgentEngineConfig_to_vertex(
6782
from_object: Union[dict[str, Any], object],
6883
parent_object: Optional[dict[str, Any]] = None,
@@ -82,6 +97,15 @@ def _CreateAgentEngineConfig_to_vertex(
8297
_ReasoningEngineSpec_to_vertex(getv(from_object, ["spec"]), to_object),
8398
)
8499

100+
if getv(from_object, ["context_spec"]) is not None:
101+
setv(
102+
parent_object,
103+
["contextSpec"],
104+
_ReasoningEngineContextSpec_to_vertex(
105+
getv(from_object, ["context_spec"]), to_object
106+
),
107+
)
108+
85109
return to_object
86110

87111

@@ -550,6 +574,15 @@ def _UpdateAgentEngineConfig_to_vertex(
550574
_ReasoningEngineSpec_to_vertex(getv(from_object, ["spec"]), to_object),
551575
)
552576

577+
if getv(from_object, ["context_spec"]) is not None:
578+
setv(
579+
parent_object,
580+
["contextSpec"],
581+
_ReasoningEngineContextSpec_to_vertex(
582+
getv(from_object, ["context_spec"]), to_object
583+
),
584+
)
585+
553586
if getv(from_object, ["update_mask"]) is not None:
554587
setv(
555588
parent_object,
@@ -1976,6 +2009,10 @@ def create(
19762009
"config must be a dict or AgentEngineConfig, but got"
19772010
f" {type(config)}."
19782011
)
2012+
context_spec = config.context_spec
2013+
if context_spec is not None:
2014+
# Conversion to a dict for _create_config
2015+
context_spec = context_spec.model_dump()
19792016
api_config = self._create_config(
19802017
mode="create",
19812018
agent_engine=agent_engine,
@@ -1986,6 +2023,7 @@ def create(
19862023
gcs_dir_name=config.gcs_dir_name,
19872024
extra_packages=config.extra_packages,
19882025
env_vars=config.env_vars,
2026+
context_spec=context_spec,
19892027
)
19902028
operation = self._create(config=api_config)
19912029
# TODO: Use a more specific link.
@@ -2029,6 +2067,7 @@ def _create_config(
20292067
gcs_dir_name: Optional[str] = None,
20302068
extra_packages: Optional[Sequence[str]] = None,
20312069
env_vars: Optional[dict[str, Union[str, Any]]] = None,
2070+
context_spec: Optional[dict[str, Any]] = None,
20322071
):
20332072
import sys
20342073
from vertexai.agent_engines import _agent_engines
@@ -2049,6 +2088,8 @@ def _create_config(
20492088
if description is not None:
20502089
update_masks.append("description")
20512090
config["description"] = description
2091+
if context_spec is not None:
2092+
config["context_spec"] = context_spec
20522093
if agent_engine is not None:
20532094
sys_version = f"{sys.version_info.major}.{sys.version_info.minor}"
20542095
gcs_dir_name = gcs_dir_name or _agent_engines._DEFAULT_GCS_DIR_NAME
@@ -2307,6 +2348,10 @@ def update(
23072348
"config must be a dict or AgentEngineConfig, but got"
23082349
f" {type(config)}."
23092350
)
2351+
context_spec = config.context_spec
2352+
if context_spec is not None:
2353+
# Conversion to a dict for _create_config
2354+
context_spec = context_spec.model_dump()
23102355
api_config = self._create_config(
23112356
mode="update",
23122357
agent_engine=agent_engine,
@@ -2317,6 +2362,7 @@ def update(
23172362
gcs_dir_name=config.gcs_dir_name,
23182363
extra_packages=config.extra_packages,
23192364
env_vars=config.env_vars,
2365+
context_spec=context_spec,
23202366
)
23212367
operation = self._update(name=name, config=api_config)
23222368
logger.info(

0 commit comments

Comments
 (0)