# Source code for bigframes.ml.preprocessing

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformers that prepare data for other estimators. This module is styled after
scikit-learn's preprocessing module: https://scikit-learn.org/stable/modules/preprocessing.html."""

from __future__ import annotations

import typing
from typing import Iterable, List, Literal, Optional, Union

import bigframes_vendored.sklearn.preprocessing._data
import bigframes_vendored.sklearn.preprocessing._discretization
import bigframes_vendored.sklearn.preprocessing._encoder
import bigframes_vendored.sklearn.preprocessing._label
import bigframes_vendored.sklearn.preprocessing._polynomial

import bigframes.core.utils as core_utils
import bigframes.pandas as bpd
from bigframes.core.logging import log_adapter
from bigframes.ml import base, core, globals, utils


@log_adapter.class_logger
class StandardScaler(
    base.Transformer,
    bigframes_vendored.sklearn.preprocessing._data.StandardScaler,
):
    __doc__ = bigframes_vendored.sklearn.preprocessing._data.StandardScaler.__doc__

    def __init__(self):
        # BQML model handle; stays None until fit() trains a TRANSFORM_ONLY model.
        self._bqml_model: Optional[core.BqmlModel] = None
        self._bqml_model_factory = globals.bqml_model_factory()
        self._base_sql_generator = globals.base_sql_generator()

    def _keys(self):
        # Equality/hash key for the transformer: the fitted model is its identity.
        return (self._bqml_model,)

    def _compile_to_sql(
        self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
    ) -> List[str]:
        """Build the list of ML.STANDARD_SCALER SQL expressions for a BQML
        TRANSFORM clause.

        Args:
            X: DataFrame whose columns are used when ``columns`` is None.
            columns: Subset of columns to scale; defaults to all columns of X.

        Returns:
            One SQL expression string per (standardized) column id.
        """
        if columns is None:
            columns = X.columns
        # Normalize labels into BQML-safe identifiers before templating SQL.
        safe_ids, _ = core_utils.get_standardized_ids(columns)
        exprs: List[str] = []
        for col_id in safe_ids:
            exprs.append(
                self._base_sql_generator.ml_standard_scaler(
                    col_id, f"standard_scaled_{col_id}"
                )
            )
        return exprs

    @classmethod
    def _parse_from_sql(cls, sql: str) -> tuple[StandardScaler, str]:
        """Recover a StandardScaler and its column label from SQL.

        Args:
            sql: SQL string of format "ML.STANDARD_SCALER({col_label}) OVER()"

        Returns:
            tuple(StandardScaler, column_label)
        """
        # Grab the text between the first "(" and the first ")".
        start = sql.find("(") + 1
        end = sql.find(")")
        raw_label = sql[start:end]
        return cls(), _unescape_id(raw_label)

    def fit(
        self,
        X: utils.ArrayType,
        y=None,  # ignored; kept for scikit-learn signature compatibility
    ) -> StandardScaler:
        (X,) = utils.batch_convert_to_dataframe(X)
        compiled_exprs = self._compile_to_sql(X)
        # A TRANSFORM_ONLY model persists the scaling statistics server-side.
        self._bqml_model = self._bqml_model_factory.create_model(
            X,
            options={"model_type": "transform_only"},
            transforms=compiled_exprs,
        )
        self._extract_output_names()
        return self

    def transform(self, X: utils.ArrayType) -> bpd.DataFrame:
        if not self._bqml_model:
            raise RuntimeError("Must be fitted before transform")
        # Reuse the fitted model's session so the query runs in the same context.
        (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
        scaled = self._bqml_model.transform(X)[self._output_names]
        return typing.cast(bpd.DataFrame, scaled)
[docs] @log_adapter.class_logger class MaxAbsScaler( base.Transformer, bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler, ): __doc__ = bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler.__doc__ def __init__(self): self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() self._base_sql_generator = globals.base_sql_generator() def _keys(self): return (self._bqml_model,) def _compile_to_sql( self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None ) -> List[str]: """Compile this transformer to a list of SQL expressions that can be included in a BQML TRANSFORM clause Args: X: DataFrame to transform. columns: transform columns. If None, transform all columns in X. Returns: a list of tuples sql_expr.""" if columns is None: columns = X.columns columns, _ = core_utils.get_standardized_ids(columns) return [ self._base_sql_generator.ml_max_abs_scaler( column, f"max_abs_scaled_{column}" ) for column in columns ] @classmethod def _parse_from_sql(cls, sql: str) -> tuple[MaxAbsScaler, str]: """Parse SQL to tuple(MaxAbsScaler, column_label). Args: sql: SQL string of format "ML.MAX_ABS_SCALER({col_label}) OVER()" Returns: tuple(MaxAbsScaler, column_label)""" # TODO: Use real sql parser col_label =