Skip to content

Commit 566a37a

Browse files
authored
fix: deflake ai_gen_bool multimodel test (#2085)
* fix: deflake ai_gen_bool multimodel test * fix lint * fix doctest too * consolidates tests under system/small * fix doctest
1 parent c1e871d commit 566a37a

File tree

4 files changed

+47
-82
lines changed

4 files changed

+47
-82
lines changed

β€Žbigframes/bigquery/_operations/ai.pyβ€Ž

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,13 @@ def generate_bool(
4747
... "col_1": ["apple", "bear", "pear"],
4848
... "col_2": ["fruit", "animal", "animal"]
4949
... })
50-
>>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"]))
50+
>>> bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"]))
5151
0 {'result': True, 'full_response': '{"candidate...
5252
1 {'result': True, 'full_response': '{"candidate...
5353
2 {'result': False, 'full_response': '{"candidat...
5454
dtype: struct<result: bool, full_response: string, status: string>[pyarrow]
5555
56-
>>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result")
57-
0 True
58-
1 True
59-
2 False
60-
Name: result, dtype: boolean
61-
62-
>>> model_params = {
63-
... "generation_config": {
64-
... "thinking_config": {
65-
... "thinking_budget": 0
66-
... }
67-
... }
68-
... }
69-
>>> bbq.ai_generate_bool(
70-
... (df["col_1"], " is a ", df["col_2"]),
71-
... endpoint="gemini-2.5-pro",
72-
... model_params=model_params,
73-
... ).struct.field("result")
56+
>>> bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result")
7457
0 True
7558
1 True
7659
2 False

β€Žtests/system/large/bigquery/__init__.pyβ€Ž

Lines changed: 0 additions & 13 deletions
This file was deleted.

β€Žtests/system/large/bigquery/test_ai.pyβ€Ž

Lines changed: 0 additions & 35 deletions
This file was deleted.

β€Žtests/system/small/bigquery/test_ai.pyβ€Ž

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
import sys
1616

1717
import pandas as pd
18-
import pandas.testing
18+
import pyarrow as pa
1919
import pytest
2020

21+
from bigframes import series
2122
import bigframes.bigquery as bbq
2223
import bigframes.pandas as bpd
2324

@@ -27,15 +28,17 @@ def test_ai_generate_bool(session):
2728
s2 = bpd.Series(["fruit", "tree"], session=session)
2829
prompt = (s1, " is a ", s2)
2930

30-
result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash").struct.field(
31-
"result"
32-
)
31+
result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")
3332

34-
pandas.testing.assert_series_equal(
35-
result.to_pandas(),
36-
pd.Series([True, False], name="result"),
37-
check_dtype=False,
38-
check_index=False,
33+
assert _contains_no_nulls(result)
34+
assert result.dtype == pd.ArrowDtype(
35+
pa.struct(
36+
(
37+
pa.field("result", pa.bool_()),
38+
pa.field("full_response", pa.string()),
39+
pa.field("status", pa.string()),
40+
)
41+
)
3942
)
4043

4144

@@ -52,11 +55,38 @@ def test_ai_generate_bool_with_model_params(session):
5255

5356
result = bbq.ai.generate_bool(
5457
prompt, endpoint="gemini-2.5-flash", model_params=model_params
55-
).struct.field("result")
58+
)
59+
60+
assert _contains_no_nulls(result)
61+
assert result.dtype == pd.ArrowDtype(
62+
pa.struct(
63+
(
64+
pa.field("result", pa.bool_()),
65+
pa.field("full_response", pa.string()),
66+
pa.field("status", pa.string()),
67+
)
68+
)
69+
)
70+
5671

57-
pandas.testing.assert_series_equal(
58-
result.to_pandas(),
59-
pd.Series([True, False], name="result"),
60-
check_dtype=False,
61-
check_index=False,
72+
def test_ai_generate_bool_multi_model(session):
73+
df = session.from_glob_path(
74+
"gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
6275
)
76+
77+
result = bbq.ai.generate_bool((df["image"], " contains an animal"))
78+
79+
assert _contains_no_nulls(result)
80+
assert result.dtype == pd.ArrowDtype(
81+
pa.struct(
82+
(
83+
pa.field("result", pa.bool_()),
84+
pa.field("full_response", pa.string()),
85+
pa.field("status", pa.string()),
86+
)
87+
)
88+
)
89+
90+
91+
def _contains_no_nulls(s: series.Series) -> bool:
92+
return len(s) == s.count()

0 commit comments

Comments
 (0)