# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers that prepare data for other estimators. This module is styled after
scikit-learn's preprocessing module: https://scikit-learn.org/stable/modules/preprocessing.html."""
from __future__ import annotations
import typing
from typing import Iterable, List, Literal, Optional, Union
import bigframes_vendored.sklearn.preprocessing._data
import bigframes_vendored.sklearn.preprocessing._discretization
import bigframes_vendored.sklearn.preprocessing._encoder
import bigframes_vendored.sklearn.preprocessing._label
import bigframes_vendored.sklearn.preprocessing._polynomial
import bigframes.core.utils as core_utils
import bigframes.pandas as bpd
from bigframes.core.logging import log_adapter
from bigframes.ml import base, core, globals, utils
[docs]
@log_adapter.class_logger
class StandardScaler(
    base.Transformer,
    bigframes_vendored.sklearn.preprocessing._data.StandardScaler,
):
    # The public documentation is taken verbatim from the vendored
    # scikit-learn class rather than being duplicated here.
    __doc__ = bigframes_vendored.sklearn.preprocessing._data.StandardScaler.__doc__

    def __init__(self):
        # No BQML model is attached until fit() creates one.
        self._bqml_model: Optional[core.BqmlModel] = None
        # Collaborators obtained from bigframes.ml.globals.
        self._bqml_model_factory = globals.bqml_model_factory()
        self._base_sql_generator = globals.base_sql_generator()

    def _keys(self):
        """Return a one-element tuple containing the current BQML model."""
        return (self._bqml_model,)

    def _compile_to_sql(
        self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
    ) -> List[str]:
        """Build the SQL expressions for this transformer, for use in a BQML
        TRANSFORM clause.

        Args:
            X: DataFrame to transform.
            columns: columns to transform. When None, every column of X is used.

        Returns:
            One SQL expression per column, each produced by the SQL generator's
            ``ml_standard_scaler``."""
        selected = X.columns if columns is None else columns
        # Only the first element of the helper's result (the ids) is needed.
        standardized_ids, _ = core_utils.get_standardized_ids(selected)
        return [
            self._base_sql_generator.ml_standard_scaler(
                col_id, f"standard_scaled_{col_id}"
            )
            for col_id in standardized_ids
        ]

    @classmethod
    def _parse_from_sql(cls, sql: str) -> tuple[StandardScaler, str]:
        """Recover a transformer and its column label from a SQL expression.

        Args:
            sql: SQL string of format "ML.STANDARD_SCALER({col_label}) OVER()"

        Returns:
            tuple(StandardScaler, column_label), where the label is the text
            between the first "(" and the first ")" after ``_unescape_id``
            has been applied to it."""
        open_paren = sql.find("(")
        close_paren = sql.find(")")
        escaped_label = sql[open_paren + 1 : close_paren]
        return cls(), _unescape_id(escaped_label)

    def fit(
        self,
        X: utils.ArrayType,
        y=None,  # ignored
    ) -> StandardScaler:
        """Create the transform-only BQML model for X.

        Args:
            X: input data; converted to a DataFrame before use.
            y: ignored.

        Returns:
            This transformer, so calls can be chained."""
        (frame,) = utils.batch_convert_to_dataframe(X)

        # Compile the per-column expressions before creating the model.
        transform_exprs = self._compile_to_sql(frame)
        model_options = {"model_type": "transform_only"}
        self._bqml_model = self._bqml_model_factory.create_model(
            frame,
            options=model_options,
            transforms=transform_exprs,
        )

        # Must run after the model is stored; the helper is defined outside
        # this class body (presumably on the base class).
        self._extract_output_names()
        return self
[docs]
@log_adapter.class_logger
class MaxAbsScaler(
base.Transformer,
bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler,
):
__doc__ = bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler.__doc__
def __init__(self):
    """Initialise the scaler with no BQML model attached."""
    # Starts as None; nothing in the constructor creates a model.
    self._bqml_model: Optional[core.BqmlModel] = None
    # Collaborators obtained from bigframes.ml.globals.
    model_factory = globals.bqml_model_factory()
    self._bqml_model_factory = model_factory
    sql_generator = globals.base_sql_generator()
    self._base_sql_generator = sql_generator
def _keys(self):
    """Return a one-element tuple containing the current BQML model."""
    current_model = self._bqml_model
    return (current_model,)
def _compile_to_sql(
    self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
) -> List[str]:
    """Build the SQL expressions for this transformer, for use in a BQML
    TRANSFORM clause.

    Args:
        X: DataFrame to transform.
        columns: columns to transform. When None, every column of X is used.

    Returns:
        One SQL expression per column, each produced by the SQL generator's
        ``ml_max_abs_scaler``."""
    selected = X.columns if columns is None else columns
    # Only the first element of the helper's result (the ids) is needed.
    standardized_ids, _ = core_utils.get_standardized_ids(selected)

    sql_exprs = []
    for col_id in standardized_ids:
        expr = self._base_sql_generator.ml_max_abs_scaler(
            col_id, f"max_abs_scaled_{col_id}"
        )
        sql_exprs.append(expr)
    return sql_exprs
@classmethod
def _parse_from_sql(cls, sql: str) -> tuple[MaxAbsScaler, str]:
"""Parse SQL to tuple(MaxAbsScaler, column_label).
Args:
sql: SQL string of format "ML.MAX_ABS_SCALER({col_label}) OVER()"
Returns:
tuple(MaxAbsScaler, column_label)"""
# TODO: Use real sql parser
col_label =