Skip to content

Commit a3de53f

Browse files
authored
feat: support pandas series in ai.generate_bool (#2086)
* feat: support pandas series in ai.generate_bool * fix mypy error * define PROMPT_TYPE with Union * fix type * update test * update comment * fix mypy * fix return type * update doc * fix doctest
1 parent bbd95e5 commit a3de53f

File tree

3 files changed

+63
-20
lines changed

3 files changed

+63
-20
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,25 @@
1919
from __future__ import annotations
2020

2121
import json
22-
from typing import Any, List, Literal, Mapping, Tuple
22+
from typing import Any, List, Literal, Mapping, Tuple, Union
2323

24-
from bigframes import clients, dtypes, series
25-
from bigframes.core import log_adapter
24+
import pandas as pd
25+
26+
from bigframes import clients, dtypes, series, session
27+
from bigframes.core import convert, log_adapter
2628
from bigframes.operations import ai_ops
2729

30+
PROMPT_TYPE = Union[
31+
series.Series,
32+
pd.Series,
33+
List[Union[str, series.Series, pd.Series]],
34+
Tuple[Union[str, series.Series, pd.Series], ...],
35+
]
36+
2837

2938
@log_adapter.method_logger(custom_base_name="bigquery_ai")
3039
def generate_bool(
31-
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
40+
prompt: PROMPT_TYPE,
3241
*,
3342
connection_id: str | None = None,
3443
endpoint: str | None = None,
@@ -51,7 +60,7 @@ def generate_bool(
5160
0 {'result': True, 'full_response': '{"candidate...
5261
1 {'result': True, 'full_response': '{"candidate...
5362
2 {'result': False, 'full_response': '{"candidat...
54-
dtype: struct<result: bool, full_response: string, status: string>[pyarrow]
63+
dtype: struct<result: bool, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
5564
5665
>>> bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result")
5766
0 True
@@ -60,8 +69,9 @@ def generate_bool(
6069
Name: result, dtype: boolean
6170
6271
Args:
63-
prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series, ...]):
64-
A mixture of Series and string literals that specifies the prompt to send to the model.
72+
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
73+
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
74+
or pandas Series.
6575
connection_id (str, optional):
6676
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
6777
If not provided, the connection from the current session will be used.
@@ -84,7 +94,7 @@ def generate_bool(
8494
Returns:
8595
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
8696
* "result": a BOOL value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
87-
* "full_response": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model.
97+
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
8898
The generated text is in the text element.
8999
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
90100
"""
@@ -104,7 +114,7 @@ def generate_bool(
104114

105115

106116
def _separate_context_and_series(
107-
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
117+
prompt: PROMPT_TYPE,
108118
) -> Tuple[List[str | None], List[series.Series]]:
109119
"""
110120
Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
@@ -123,18 +133,19 @@ def _separate_context_and_series(
123133
return [None], [prompt]
124134

125135
prompt_context: List[str | None] = []
126-
series_list: List[series.Series] = []
136+
series_list: List[series.Series | pd.Series] = []
127137

138+
session = None
128139
for item in prompt:
129140
if isinstance(item, str):
130141
prompt_context.append(item)
131142

132-
elif isinstance(item, series.Series):
143+
elif isinstance(item, (series.Series, pd.Series)):
133144
prompt_context.append(None)
134145

135-
if item.dtype == dtypes.OBJ_REF_DTYPE:
136-
# Multi-model support
137-
item = item.blob.read_url()
146+
if isinstance(item, series.Series) and session is None:
147+
# Use the first available BF session if there's any.
148+
session = item._session
138149
series_list.append(item)
139150

140151
else:
@@ -143,7 +154,20 @@ def _separate_context_and_series(
143154
if not series_list:
144155
raise ValueError("Please provide at least one Series in the prompt")
145156

146-
return prompt_context, series_list
157+
converted_list = [_convert_series(s, session) for s in series_list]
158+
159+
return prompt_context, converted_list
160+
161+
162+
def _convert_series(
163+
s: series.Series | pd.Series, session: session.Session | None
164+
) -> series.Series:
165+
result = convert.to_bf_series(s, default_index=None, session=session)
166+
167+
if result.dtype == dtypes.OBJ_REF_DTYPE:
168+
# Support multimodel
169+
return result.blob.read_url()
170+
return result
147171

148172

149173
def _resolve_connection_id(series: series.Series, connection_id: str | None):

bigframes/operations/ai_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
4040
pa.struct(
4141
(
4242
pa.field("result", pa.bool_()),
43-
pa.field("full_response", pa.string()),
43+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
4444
pa.field("status", pa.string()),
4545
)
4646
)

tests/system/small/bigquery/test_ai.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pyarrow as pa
1919
import pytest
2020

21-
from bigframes import series
21+
from bigframes import dtypes, series
2222
import bigframes.bigquery as bbq
2323
import bigframes.pandas as bpd
2424

@@ -35,7 +35,26 @@ def test_ai_generate_bool(session):
3535
pa.struct(
3636
(
3737
pa.field("result", pa.bool_()),
38-
pa.field("full_response", pa.string()),
38+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
39+
pa.field("status", pa.string()),
40+
)
41+
)
42+
)
43+
44+
45+
def test_ai_generate_bool_with_pandas(session):
46+
s1 = pd.Series(["apple", "bear"])
47+
s2 = bpd.Series(["fruit", "tree"], session=session)
48+
prompt = (s1, " is a ", s2)
49+
50+
result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")
51+
52+
assert _contains_no_nulls(result)
53+
assert result.dtype == pd.ArrowDtype(
54+
pa.struct(
55+
(
56+
pa.field("result", pa.bool_()),
57+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
3958
pa.field("status", pa.string()),
4059
)
4160
)
@@ -62,7 +81,7 @@ def test_ai_generate_bool_with_model_params(session):
6281
pa.struct(
6382
(
6483
pa.field("result", pa.bool_()),
65-
pa.field("full_response", pa.string()),
84+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
6685
pa.field("status", pa.string()),
6786
)
6887
)
@@ -81,7 +100,7 @@ def test_ai_generate_bool_multi_model(session):
81100
pa.struct(
82101
(
83102
pa.field("result", pa.bool_()),
84-
pa.field("full_response", pa.string()),
103+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
85104
pa.field("status", pa.string()),
86105
)
87106
)

0 commit comments

Comments
 (0)