20 changes: 15 additions & 5 deletions bigframes/series.py
@@ -1898,9 +1898,13 @@ def _groupby_values(
         )
 
     def apply(
-        self, func, by_row: typing.Union[typing.Literal["compat"], bool] = "compat"
+        self,
+        func,
+        by_row: typing.Union[typing.Literal["compat"], bool] = "compat",
+        *,
+        args: typing.Tuple = (),
Contributor:
In pandas, `args` can be a positional argument while `by_row` is a keyword-only argument: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Series.apply.html#pandas-series-apply. I would argue that we can adhere to that right now, as I don't anticipate many people passing `by_row` as a positional argument; if not, we should add an item for BigFrames 3.0 to make the breaking change.

Contributor Author:
Yeah, I also thought about it. It's difficult to fully match the pandas API here at the moment. I have added some comments about potential breaking changes in the future. However, given that this is an edge case, I recommend we defer any such change until there is a clear need based on user feedback.
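For reference, a minimal sketch of the signature difference under discussion (`add_n` and `s` are hypothetical; the pandas signature is from the linked docs, the BigFrames one from this PR):

```python
import pandas as pd

def add_n(x, n):
    return x + n

s = pd.Series([1, 2, 3])

# pandas 2.x (per the linked docs):
#   Series.apply(func, convert_dtype=..., args=(), *, by_row="compat", **kwargs)
# `args` may be passed positionally, while `by_row` is keyword-only.
print(s.apply(add_n, args=(10,)))  # 11, 12, 13

# BigFrames after this PR:
#   Series.apply(func, by_row="compat", *, args=())
# `by_row` is positional-or-keyword and `args` is keyword-only -- the reverse
# of pandas. Passing `args` by keyword, as above, works in both.
```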

     ) -> Series:
-        # TODO(shobs, b/274645634): Support convert_dtype, args, **kwargs
+        # TODO(shobs, b/274645634): Support convert_dtype, **kwargs
         # is actually a ternary op
 
         if by_row not in ["compat", False]:
@@ -1944,10 +1948,16 @@ def apply(
             raise
 
         # We are working with bigquery function at this point
-        result_series = self._apply_unary_op(
-            ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
-        )
+        if args:
+            result_series = self._apply_nary_op(
+                ops.NaryRemoteFunctionOp(function_def=func.udf_def), args
+            )
+        else:
+            result_series = self._apply_unary_op(
+                ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
+            )
         result_series = func._post_process_series(result_series)
 
         return result_series
 
     def combine(
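A minimal usage sketch of the new keyword, modeled on the tests below (assumes an existing BigFrames `session` and DataFrame `df`; `tag` is a hypothetical function):

```python
# Deploy a scalar function that takes the series value plus two extra arguments.
@session.remote_function(reuse=False, cloud_function_service_account="default")
def tag(x: int, scale: float, prefix: str) -> str:
    return f"{prefix}{x * scale}"

# The series element binds to `x`; `args` supplies the remaining parameters
# in order, and must be passed by keyword (it is keyword-only in this PR).
result = df["int64_col"].apply(tag, args=(2.0, "row-"))
```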
33 changes: 33 additions & 0 deletions tests/system/large/functions/test_managed_function.py
@@ -1111,3 +1111,36 @@ def _is_positive(s):
    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)


def test_managed_function_series_apply_args(session, dataset_id, scalars_dfs):
    try:

        with pytest.warns(bfe.PreviewWarning, match="udf is in preview."):

            @session.udf(dataset=dataset_id, name=prefixer.create_prefix())
            def foo_list(x: int, y0: float, y1: bytes, y2: bool) -> list[str]:
                return [str(x), str(y0), str(y1), str(y2)]

        scalars_df, scalars_pandas_df = scalars_dfs

        bf_result_col = scalars_df["int64_too"].apply(
            foo_list, args=(12.34, b"hello world", False)
        )
        bf_result = (
            scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
        )

        pd_result_col = scalars_pandas_df["int64_too"].apply(
            foo_list, args=(12.34, b"hello world", False)
        )
        pd_result = (
            scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
        )

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False)
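For intuition, the pandas semantics both tests compare against reduce to an element-wise call with the `args` tuple appended after the series value; a plain-pandas sketch:

```python
import pandas as pd

def foo_list(x, y0, y1, y2):
    return [str(x), str(y0), str(y1), str(y2)]

s = pd.Series([1, 2])
extra = (12.34, b"hello world", False)

# `args` entries are passed after the element value, in order.
applied = s.apply(foo_list, args=extra)
manual = s.map(lambda x: foo_list(x, *extra))
assert applied.tolist() == manual.tolist()
```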
53 changes: 53 additions & 0 deletions tests/system/large/functions/test_remote_function.py
@@ -2969,3 +2969,56 @@ def _ten_times(x):
    finally:
        # Clean up the gcp assets created for the remote function.
        cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False)


@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_series_apply_args(session, dataset_id, scalars_dfs):
    try:

        @session.remote_function(
            dataset=dataset_id,
            reuse=False,
            cloud_function_service_account="default",
        )
        def foo(x: int, y: bool, z: float) -> str:
            if y:
                return f"{x}: y is True."
            if z > 0.0:
                return f"{x}: y is False and z is positive."
            return f"{x}: y is False and z is non-positive."

        scalars_df, scalars_pandas_df = scalars_dfs

        args1 = (True, 10.0)
        bf_result_col = scalars_df["int64_too"].apply(foo, args=args1)
        bf_result = (
            scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
        )

        pd_result_col = scalars_pandas_df["int64_too"].apply(foo, args=args1)
        pd_result = (
            scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
        )

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

        args2 = (False, -10.0)
        foo_ref = session.read_gbq_function(foo.bigframes_bigquery_function)

        bf_result_col = scalars_df["int64_too"].apply(foo_ref, args=args2)
        bf_result = (
            scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
        )

        pd_result_col = scalars_pandas_df["int64_too"].apply(foo, args=args2)
        pd_result = (
            scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
        )

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the remote function.
        cleanup_function_assets(foo, session.bqclient, ignore_failures=False)