@@ -963,3 +963,115 @@ def float_parser(row):
963
963
cleanup_function_assets (
964
964
float_parser_mf , session .bqclient , ignore_failures = False
965
965
)
966
+
967
+
968
+ def test_managed_function_df_where (session , dataset_id , scalars_dfs ):
969
+ try :
970
+
971
+ # The return type has to be bool type for callable where condition.
972
+ def is_sum_positive (a , b ):
973
+ return a + b > 0
974
+
975
+ is_sum_positive_mf = session .udf (
976
+ input_types = [int , int ],
977
+ output_type = bool ,
978
+ dataset = dataset_id ,
979
+ name = prefixer .create_prefix (),
980
+ )(is_sum_positive )
981
+
982
+ scalars_df , scalars_pandas_df = scalars_dfs
983
+ int64_cols = ["int64_col" , "int64_too" ]
984
+
985
+ bf_int64_df = scalars_df [int64_cols ]
986
+ bf_int64_df_filtered = bf_int64_df .dropna ()
987
+ pd_int64_df = scalars_pandas_df [int64_cols ]
988
+ pd_int64_df_filtered = pd_int64_df .dropna ()
989
+
990
+ # Use callable condition in dataframe.where method.
991
+ bf_result = bf_int64_df_filtered .where (is_sum_positive_mf ).to_pandas ()
992
+ # Pandas doesn't support such case, use following as workaround.
993
+ pd_result = pd_int64_df_filtered .where (pd_int64_df_filtered .sum (axis = 1 ) > 0 )
994
+
995
+ # Ignore any dtype difference.
996
+ pandas .testing .assert_frame_equal (bf_result , pd_result , check_dtype = False )
997
+
998
+ # Make sure the read_gbq_function path works for this function.
999
+ is_sum_positive_ref = session .read_gbq_function (
1000
+ function_name = is_sum_positive_mf .bigframes_bigquery_function
1001
+ )
1002
+
1003
+ bf_result_gbq = bf_int64_df_filtered .where (
1004
+ is_sum_positive_ref , - bf_int64_df_filtered
1005
+ ).to_pandas ()
1006
+ pd_result_gbq = pd_int64_df_filtered .where (
1007
+ pd_int64_df_filtered .sum (axis = 1 ) > 0 , - pd_int64_df_filtered
1008
+ )
1009
+
1010
+ # Ignore any dtype difference.
1011
+ pandas .testing .assert_frame_equal (
1012
+ bf_result_gbq , pd_result_gbq , check_dtype = False
1013
+ )
1014
+
1015
+ finally :
1016
+ # Clean up the gcp assets created for the managed function.
1017
+ cleanup_function_assets (
1018
+ is_sum_positive_mf , session .bqclient , ignore_failures = False
1019
+ )
1020
+
1021
+
1022
+ def test_managed_function_df_where_series (session , dataset_id , scalars_dfs ):
1023
+ try :
1024
+
1025
+ # The return type has to be bool type for callable where condition.
1026
+ def is_sum_positive_series (s ):
1027
+ return s ["int64_col" ] + s ["int64_too" ] > 0
1028
+
1029
+ is_sum_positive_series_mf = session .udf (
1030
+ input_types = bigframes .series .Series ,
1031
+ output_type = bool ,
1032
+ dataset = dataset_id ,
1033
+ name = prefixer .create_prefix (),
1034
+ )(is_sum_positive_series )
1035
+
1036
+ scalars_df , scalars_pandas_df = scalars_dfs
1037
+ int64_cols = ["int64_col" , "int64_too" ]
1038
+
1039
+ bf_int64_df = scalars_df [int64_cols ]
1040
+ bf_int64_df_filtered = bf_int64_df .dropna ()
1041
+ pd_int64_df = scalars_pandas_df [int64_cols ]
1042
+ pd_int64_df_filtered = pd_int64_df .dropna ()
1043
+
1044
+ # Use callable condition in dataframe.where method.
1045
+ bf_result = bf_int64_df_filtered .where (is_sum_positive_series ).to_pandas ()
1046
+ pd_result = pd_int64_df_filtered .where (is_sum_positive_series )
1047
+
1048
+ # Ignore any dtype difference.
1049
+ pandas .testing .assert_frame_equal (bf_result , pd_result , check_dtype = False )
1050
+
1051
+ # Make sure the read_gbq_function path works for this function.
1052
+ is_sum_positive_series_ref = session .read_gbq_function (
1053
+ function_name = is_sum_positive_series_mf .bigframes_bigquery_function ,
1054
+ is_row_processor = True ,
1055
+ )
1056
+
1057
+ # This is for callable `other` arg in dataframe.where method.
1058
+ def func_for_other (x ):
1059
+ return - x
1060
+
1061
+ bf_result_gbq = bf_int64_df_filtered .where (
1062
+ is_sum_positive_series_ref , func_for_other
1063
+ ).to_pandas ()
1064
+ pd_result_gbq = pd_int64_df_filtered .where (
1065
+ is_sum_positive_series , func_for_other
1066
+ )
1067
+
1068
+ # Ignore any dtype difference.
1069
+ pandas .testing .assert_frame_equal (
1070
+ bf_result_gbq , pd_result_gbq , check_dtype = False
1071
+ )
1072
+
1073
+ finally :
1074
+ # Clean up the gcp assets created for the managed function.
1075
+ cleanup_function_assets (
1076
+ is_sum_positive_series_mf , session .bqclient , ignore_failures = False
1077
+ )
0 commit comments