Skip to content

Commit 6e92dd6

Browse files
committed
supports dataframegroupby and adds tests/docs
1 parent eec79d1 commit 6e92dd6

File tree

2 files changed

+127
-6
lines changed

2 files changed

+127
-6
lines changed

β€Žbigframes/bigquery/__init__.pyβ€Ž

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222

2323
import typing
2424

25+
import bigframes.constants as constants
2526
import bigframes.core.groupby as groupby
2627
import bigframes.operations as ops
2728
import bigframes.operations.aggregations as agg_ops
2829

2930
if typing.TYPE_CHECKING:
31+
import bigframes.dataframe as dataframe
3032
import bigframes.series as series
3133

3234

@@ -54,6 +56,10 @@ def array_length(series: series.Series) -> series.Series:
5456
2 2
5557
dtype: Int64
5658
59+
Args:
60+
series (bigframes.series.Series):
61+
A Series with array columns.
62+
5763
Returns:
5864
bigframes.series.Series: A Series of integer values indicating
5965
the length of each element in the Series.
@@ -62,5 +68,53 @@ def array_length(series: series.Series) -> series.Series:
6268
return series._apply_unary_op(ops.len_op)
6369

6470

65-
def array_agg(groupby_series: groupby.SeriesGroupBy) -> series.Series:
66-
return groupby_series._aggregate(agg_ops.ArrayAggOp())
71+
def array_agg(
72+
obj: groupby.SeriesGroupBy | groupby.DataFrameGroupBy,
73+
) -> series.Series | dataframe.DataFrame:
74+
"""Group data and create arrays from selected columns, omitting NULLs to avoid
75+
BigQuery errors (NULLs not allowed in arrays).
76+
77+
**Examples:**
78+
79+
>>> import bigframes.pandas as bpd
80+
>>> import bigframes.bigquery as bbq
81+
>>> bpd.options.display.progress_bar = None
82+
83+
For a SeriesGroupBy object:
84+
85+
>>> lst = ['a', 'a', 'b', 'b', 'a']
86+
>>> s = bpd.Series([1, 2, 3, 4, np.nan], index=lst)
87+
>>> bbq.array_agg(s.groupby(level=0))
88+
a [1. 2.]
89+
b [3. 4.]
90+
dtype: list<item: double>[pyarrow]
91+
92+
For a DataFrameGroupBy object:
93+
94+
>>> l = [[1, 2, 3], [1, None, 4], [2, 1, 3], [1, 2, 2]]
95+
>>> df = bpd.DataFrame(l, columns=["a", "b", "c"])
96+
>>> bbq.array_agg(df.groupby(by=["b"]))
97+
b a c
98+
1.0 [2] [3]
99+
2.0 [1 1] [3 2]
100+
2 rows Γ— 2 columns
101+
102+
[2 rows x 2 columns in total]
103+
104+
Args:
105+
obj (groupby.SeriesGroupBy | groupby.DataFrameGroupBy):
106+
A GroupBy object to be applied the function.
107+
108+
Returns:
109+
bigframes.series.Series | bigframes.dataframe.DataFrame: A Series or
110+
DataFrame containing aggregated array columns, and indexed by the
111+
original group columns.
112+
"""
113+
if isinstance(obj, groupby.SeriesGroupBy):
114+
return obj._aggregate(agg_ops.ArrayAggOp())
115+
elif isinstance(obj, groupby.DataFrameGroupBy):
116+
return obj._aggregate_all(agg_ops.ArrayAggOp(), numeric_only=False)
117+
else:
118+
raise ValueError(
119+
f"Unsupported type {type(obj)} to apply `array_agg` function. {constants.FEEDBACK_LINK}"
120+
)

β€Žtests/system/small/bigquery/test_array.pyβ€Ž

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
import pytest
1718

1819
import bigframes.bigquery as bbq
1920
import bigframes.pandas as bpd
@@ -23,10 +24,76 @@ def test_array_length():
2324
series = bpd.Series([["A", "AA", "AAA"], ["BB", "B"], np.nan, [], ["C"]])
2425
# TODO(b/336880368): Allow for NULL values to be input for ARRAY columns.
2526
# Once we actually store NULL values, this will be NULL where the input is NULL.
26-
expected = pd.Series([3, 2, 0, 0, 1])
27+
expected = bpd.Series([3, 2, 0, 0, 1])
2728
pd.testing.assert_series_equal(
2829
bbq.array_length(series).to_pandas(),
29-
expected,
30-
check_dtype=False,
31-
check_index_type=False,
30+
expected.to_pandas(),
31+
)
32+
33+
34+
@pytest.mark.parametrize(
35+
("input_data", "output_data"),
36+
[
37+
pytest.param([1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]], id="ints"),
38+
pytest.param(
39+
["e", "d", "c", "b", "a"],
40+
[["e", "d"], ["c", "b"], ["a"]],
41+
id="reverse_strings",
42+
),
43+
pytest.param(
44+
[1.0, 2.0, np.nan, np.nan, np.nan], [[1.0, 2.0], [], []], id="nans"
45+
),
46+
pytest.param(
47+
[{"A": {"x": 1.0}}, {"A": {"z": 4.0}}, {}, {"B": "b"}, np.nan],
48+
[[{"A": {"x": 1.0}}, {"A": {"z": 4.0}}], [{}, {"B": "b"}], []],
49+
id="structs",
50+
),
51+
],
52+
)
53+
def test_array_agg_w_series(input_data, output_data):
54+
input_index = ["a", "a", "b", "b", "c"]
55+
series = bpd.Series(input_data, index=input_index)
56+
result = bbq.array_agg(series.groupby(level=0))
57+
58+
expected = bpd.Series(output_data, index=["a", "b", "c"])
59+
pd.testing.assert_series_equal(
60+
result.to_pandas(),
61+
expected.to_pandas(),
62+
)
63+
64+
65+
def test_array_agg_w_dataframe():
66+
data = {
67+
"a": [1, 1, 2, 1],
68+
"b": [2, None, 1, 2],
69+
"c": [3, 4, 3, 2],
70+
}
71+
df = bpd.DataFrame(data)
72+
result = bbq.array_agg(df.groupby(by=["b"]))
73+
74+
expected_data = {
75+
"b": [1.0, 2.0],
76+
"a": [[2], [1, 1]],
77+
"c": [[3], [3, 2]],
78+
}
79+
expected = bpd.DataFrame(expected_data).set_index("b")
80+
81+
pd.testing.assert_frame_equal(
82+
result.to_pandas(),
83+
expected.to_pandas(),
84+
)
85+
86+
def assert_array_agg_matches_after_explode():
87+
data = {
88+
"index": np.arange(10),
89+
"a": [np.random.randint(0, 10, 10) for _ in range(10)],
90+
"b": [np.random.randint(0, 10, 10) for _ in range(10)],
91+
}
92+
df = bpd.DataFrame(data).set_index("index")
93+
result = bbq.array_agg(df.explode(["a", "b"]).groupby(level=0))
94+
result.index.name = "index"
95+
96+
pd.testing.assert_frame_equal(
97+
result.to_pandas(),
98+
df.to_pandas(),
3299
)

0 commit comments

Comments
 (0)