Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions bigframes/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import bigframes.operations.aggregations as agg_ops


def const(value: typing.Hashable, dtype: dtypes.ExpressionType = None) -> Expression:
def const(
value: typing.Hashable, dtype: dtypes.ExpressionType = None
) -> ScalarConstantExpression:
return ScalarConstantExpression(value, dtype or dtypes.infer_literal_type(value))


Expand Down Expand Up @@ -141,6 +143,9 @@ class ScalarConstantExpression(Expression):
def is_const(self) -> bool:
return True

def rename(self, name_mapping: Mapping[str, str]) -> ScalarConstantExpression:
return self

def output_type(
self, input_types: dict[str, bigframes.dtypes.Dtype]
) -> dtypes.ExpressionType:
Expand All @@ -167,7 +172,7 @@ class UnboundVariableExpression(Expression):
def unbound_variables(self) -> typing.Tuple[str, ...]:
return (self.id,)

def rename(self, name_mapping: Mapping[str, str]) -> Expression:
def rename(self, name_mapping: Mapping[str, str]) -> UnboundVariableExpression:
if self.id in name_mapping:
return UnboundVariableExpression(name_mapping[self.id])
else:
Expand Down
74 changes: 57 additions & 17 deletions bigframes/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import typing
from typing import List, Sequence
from typing import List, Sequence, Union

import bigframes_vendored.constants as constants
import bigframes_vendored.pandas.pandas._typing as vendored_pandas_typing
Expand Down Expand Up @@ -180,9 +180,10 @@ def _apply_binary_op(
(self_col, other_col, block) = self._align(other_series, how=alignment)

name = self._name
# Drop name if both objects have name attr, but they don't match
if (
hasattr(other, "name")
and other.name != self._name
and other_series.name != self._name
and alignment == "outer"
):
name = None
Expand All @@ -208,41 +209,78 @@ def _apply_nary_op(
ignore_self=False,
):
"""Applies an n-ary operator to the series and others."""
values, block = self._align_n(others, ignore_self=ignore_self)
block, result_id = block.apply_nary_op(
values,
op,
self._name,
values, block = self._align_n(
others, ignore_self=ignore_self, cast_scalars=False
)
block, result_id = block.project_expr(op.as_expr(*values))
return series.Series(block.select_column(result_id))

def _apply_binary_aggregation(
self, other: series.Series, stat: agg_ops.BinaryAggregateOp
) -> float:
(left, right, block) = self._align(other, how="outer")
assert isinstance(left, ex.UnboundVariableExpression)
assert isinstance(right, ex.UnboundVariableExpression)
return block.get_binary_stat(left.id, right.id, stat)

AlignedExprT = Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]

return block.get_binary_stat(left, right, stat)
@typing.overload
def _align(
self, other: series.Series, how="outer"
) -> tuple[
ex.UnboundVariableExpression,
ex.UnboundVariableExpression,
blocks.Block,
]:
...

def _align(self, other: series.Series, how="outer") -> tuple[str, str, blocks.Block]: # type: ignore
@typing.overload
def _align(
self, other: typing.Union[series.Series, scalars.Scalar], how="outer"
) -> tuple[ex.UnboundVariableExpression, AlignedExprT, blocks.Block,]:
...

def _align(
self, other: typing.Union[series.Series, scalars.Scalar], how="outer"
) -> tuple[ex.UnboundVariableExpression, AlignedExprT, blocks.Block,]:
"""Aligns the series value with another scalar or series object. Returns new left column id, right column id and joined tabled expression."""
values, block = self._align_n(
[
other,
],
how,
)
return (values[0], values[1], block)
return (typing.cast(ex.UnboundVariableExpression, values[0]), values[1], block)

def _align3(self, other1: series.Series | scalars.Scalar, other2: series.Series | scalars.Scalar, how="left") -> tuple[ex.UnboundVariableExpression, AlignedExprT, AlignedExprT, blocks.Block]: # type: ignore
"""Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression."""
values, index = self._align_n([other1, other2], how)
return (
typing.cast(ex.UnboundVariableExpression, values[0]),
values[1],
values[2],
index,
)

def _align_n(
self,
others: typing.Sequence[typing.Union[series.Series, scalars.Scalar]],
how="outer",
ignore_self=False,
) -> tuple[typing.Sequence[str], blocks.Block]:
cast_scalars: bool = True,
) -> tuple[
typing.Sequence[
Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]
],
blocks.Block,
]:
if ignore_self:
value_ids: List[str] = []
value_ids: List[
Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]
] = []
else:
value_ids = [self._value_column]
value_ids = [ex.free_var(self._value_column)]

block = self._block
for other in others:
Expand All @@ -252,14 +290,16 @@ def _align_n(
get_column_right,
) = block.join(other._block, how=how)
value_ids = [
*[get_column_left[value] for value in value_ids],
get_column_right[other._value_column],
*[value.rename(get_column_left) for value in value_ids],
ex.free_var(get_column_right[other._value_column]),
]
else:
# Will throw if can't interpret as scalar.
dtype = typing.cast(bigframes.dtypes.Dtype, self._dtype)
block, constant_col_id = block.create_constant(other, dtype=dtype)
value_ids = [*value_ids, constant_col_id]
value_ids = [
*value_ids,
ex.const(other, dtype=dtype if cast_scalars else None),
]
return (value_ids, block)

def _throw_if_null_index(self, opname: str):
Expand Down
33 changes: 9 additions & 24 deletions bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,23 +445,13 @@ def between(self, left, right, inclusive="both"):
)

def case_when(self, caselist) -> Series:
cases = list(itertools.chain(*caselist, (True, self)))
return self._apply_nary_op(
ops.case_when_op,
tuple(
itertools.chain(
itertools.chain(*caselist),
# Fallback to current value if no other matches.
(
# We make a Series with a constant value to avoid casts to
# types other than boolean.
Series(True, index=self.index, dtype=pandas.BooleanDtype()),
self,
),
),
),
cases,
# Self is already included in "others".
ignore_self=True,
)
).rename(self.name)

@validations.requires_ordering()
def cumsum(self) -> Series:
Expand Down Expand Up @@ -1116,8 +1106,8 @@ def ne(self, other: object) -> Series:

def where(self, cond, other=None):
value_id, cond_id, other_id, block = self._align3(cond, other)
block, result_id = block.apply_ternary_op(
value_id, cond_id, other_id, ops.where_op
block, result_id = block.project_expr(
ops.where_op.as_expr(value_id, cond_id, other_id)
)
return Series(block.select_column(result_id).with_column_labels([self.name]))

Expand All @@ -1129,8 +1119,8 @@ def clip(self, lower, upper):
if upper is None:
return self._apply_binary_op(lower, ops.maximum_op, alignment="left")
value_id, lower_id, upper_id, block = self._align3(lower, upper)
block, result_id = block.apply_ternary_op(
value_id, lower_id, upper_id, ops.clip_op
block, result_id = block.project_expr(
ops.clip_op.as_expr(value_id, lower_id, upper_id),
)
return Series(block.select_column(result_id).with_column_labels([self.name]))

Expand Down Expand Up @@ -1242,8 +1232,8 @@ def __getitem__(self, indexer):
return self.iloc[indexer]
if isinstance(indexer, Series):
(left, right, block) = self._align(indexer, "left")
block = block.filter_by_id(right)
block = block.select_column(left)
block = block.filter(right)
block = block.select_column(left.id)
return Series(block)
return self.loc[indexer]

Expand All @@ -1262,11 +1252,6 @@ def __getattr__(self, key: str):
else:
raise AttributeError(key)

def _align3(self, other1: Series | scalars.Scalar, other2: Series | scalars.Scalar, how="left") -> tuple[str, str, str, blocks.Block]: # type: ignore
"""Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression."""
values, index = self._align_n([other1, other2], how)
return (values[0], values[1], values[2], index)

def _apply_aggregation(
self, op: agg_ops.UnaryAggregateOp | agg_ops.NullaryAggregateOp
) -> Any:
Expand Down
13 changes: 8 additions & 5 deletions tests/system/small/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2709,27 +2709,30 @@ def test_between(scalars_df_index, scalars_pandas_df_index, left, right, inclusi
)


def test_case_when(scalars_df_index, scalars_pandas_df_index):
def test_series_case_when(scalars_dfs_maybe_ordered):
pytest.importorskip(
"pandas",
minversion="2.2.0",
reason="case_when added in pandas 2.2.0",
)
scalars_df, scalars_pandas_df = scalars_dfs_maybe_ordered

bf_series = scalars_df_index["int64_col"]
pd_series = scalars_pandas_df_index["int64_col"]
bf_series = scalars_df["int64_col"]
pd_series = scalars_pandas_df["int64_col"]

# TODO(tswast): pandas case_when appears to assume True when a value is
# null. I suspect this should be considered a bug in pandas.
bf_result = bf_series.case_when(
[
((bf_series > 100).fillna(True), 1000),
((bf_series > 100).fillna(True), bf_series - 1),
((bf_series > 0).fillna(True), pd.NA),
((bf_series < -100).fillna(True), -1000),
]
).to_pandas()
pd_result = pd_series.case_when(
[
(pd_series > 100, 1000),
(pd_series > 100, pd_series - 1),
(pd_series > 0, pd.NA),
(pd_series < -100, -1000),
]
)
Expand Down