15
15
import sys
16
16
17
17
import pandas as pd
18
- import pandas . testing
18
+ import pyarrow as pa
19
19
import pytest
20
20
21
+ from bigframes import series
21
22
import bigframes .bigquery as bbq
22
23
import bigframes .pandas as bpd
23
24
@@ -27,15 +28,17 @@ def test_ai_generate_bool(session):
27
28
s2 = bpd .Series (["fruit" , "tree" ], session = session )
28
29
prompt = (s1 , " is a " , s2 )
29
30
30
- result = bbq .ai .generate_bool (prompt , endpoint = "gemini-2.5-flash" ).struct .field (
31
- "result"
32
- )
31
+ result = bbq .ai .generate_bool (prompt , endpoint = "gemini-2.5-flash" )
33
32
34
- pandas .testing .assert_series_equal (
35
- result .to_pandas (),
36
- pd .Series ([True , False ], name = "result" ),
37
- check_dtype = False ,
38
- check_index = False ,
33
+ assert _contains_no_nulls (result )
34
+ assert result .dtype == pd .ArrowDtype (
35
+ pa .struct (
36
+ (
37
+ pa .field ("result" , pa .bool_ ()),
38
+ pa .field ("full_response" , pa .string ()),
39
+ pa .field ("status" , pa .string ()),
40
+ )
41
+ )
39
42
)
40
43
41
44
@@ -52,11 +55,38 @@ def test_ai_generate_bool_with_model_params(session):
52
55
53
56
result = bbq .ai .generate_bool (
54
57
prompt , endpoint = "gemini-2.5-flash" , model_params = model_params
55
- ).struct .field ("result" )
58
+ )
59
+
60
+ assert _contains_no_nulls (result )
61
+ assert result .dtype == pd .ArrowDtype (
62
+ pa .struct (
63
+ (
64
+ pa .field ("result" , pa .bool_ ()),
65
+ pa .field ("full_response" , pa .string ()),
66
+ pa .field ("status" , pa .string ()),
67
+ )
68
+ )
69
+ )
70
+
56
71
57
- pandas .testing .assert_series_equal (
58
- result .to_pandas (),
59
- pd .Series ([True , False ], name = "result" ),
60
- check_dtype = False ,
61
- check_index = False ,
72
+ def test_ai_generate_bool_multi_model (session ):
73
+ df = session .from_glob_path (
74
+ "gs://bigframes-dev-testing/a_multimodel/images/*" , name = "image"
62
75
)
76
+
77
+ result = bbq .ai .generate_bool ((df ["image" ], " contains an animal" ))
78
+
79
+ assert _contains_no_nulls (result )
80
+ assert result .dtype == pd .ArrowDtype (
81
+ pa .struct (
82
+ (
83
+ pa .field ("result" , pa .bool_ ()),
84
+ pa .field ("full_response" , pa .string ()),
85
+ pa .field ("status" , pa .string ()),
86
+ )
87
+ )
88
+ )
89
+
90
+
91
+ def _contains_no_nulls (s : series .Series ) -> bool :
92
+ return len (s ) == s .count ()
0 commit comments