
Apply multiple StandardScalers to individual groups?

Is there a pythonic way to chain together sklearn's StandardScaler instances to independently scale data within groups? For example, if I wanted to independently scale the features of the iris dataset by class, I could use the following code:

import pandas as pd
from sklearn.datasets import load_iris
data = load_iris()
df = pd.DataFrame(data['data'], columns=data['feature_names'])
df['class'] = data['target']

means = df.groupby('class').mean()
stds = df.groupby('class').std()

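# drop(columns=...) replaces the positional axis argument,
# which newer pandas versions no longer accept
df_rescaled = (
    (df.drop(columns='class') - means.reindex(df['class']).values) /
     stds.reindex(df['class']).values)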

Here, I'm subtracting the mean and dividing by the standard deviation of each group independently. But it's somewhat awkward to carry around these means and standard deviations and, essentially, replicate the behavior of StandardScaler whenever I have a categorical variable I'd like to control for.

Is there a more pythonic / sklearn-friendly way to implement this type of scaling?

pstjohn asked Apr 09 '19 21:04


1 Answer

Sure, you can use any sklearn operation and apply it to a groupby object.

First, a little convenience wrapper:

import typing
import pandas as pd

class SklearnWrapper:
    def __init__(self, transform: typing.Callable):
        self.transform = transform

    def __call__(self, df):
        transformed = self.transform.fit_transform(df.values)
        return pd.DataFrame(transformed, columns=df.columns, index=df.index)

This one will apply any sklearn transform you pass into it to a group.

And finally simple usage:

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

data = load_iris()
df = pd.DataFrame(data["data"], columns=data["feature_names"])
df["class"] = data["target"]

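df_rescaled = (
    # Select only the feature columns so the class label is never scaled;
    # group_keys=False keeps the original row index on the result
    df.groupby("class", group_keys=False)[data["feature_names"]]
    .apply(SklearnWrapper(StandardScaler()))
)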
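
As a sanity check on the grouped scaling above, here is a minimal self-contained sketch that compares it with a manual per-group z-score. The helper `scale_group` is a hypothetical stand-in for the wrapper, and the names `manual`, `group_means` and `group_stds` are only for illustration. One detail worth noting: StandardScaler divides by the population standard deviation (ddof=0), whereas pandas' `std()` defaults to ddof=1, so the manual version has to pass `ddof=0` to match.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

data = load_iris()
feature_cols = list(data["feature_names"])
df = pd.DataFrame(data["data"], columns=feature_cols)
df["class"] = data["target"]


def scale_group(group: pd.DataFrame) -> pd.DataFrame:
    """Fit a fresh StandardScaler on one group and return the scaled frame."""
    scaled = StandardScaler().fit_transform(group.values)
    return pd.DataFrame(scaled, columns=group.columns, index=group.index)


# Scale the feature columns of every class independently
df_rescaled = df.groupby("class", group_keys=False)[feature_cols].apply(scale_group)

# Manual per-group z-score; ddof=0 matches StandardScaler's population std
means = df.groupby("class")[feature_cols].mean()
stds = df.groupby("class")[feature_cols].std(ddof=0)
manual = (
    df[feature_cols] - means.loc[df["class"]].to_numpy()
) / stds.loc[df["class"]].to_numpy()

# Per-class statistics of the scaled data: means near 0, population stds near 1
group_means = df_rescaled.groupby(df["class"]).mean()
group_stds = df_rescaled.groupby(df["class"]).std(ddof=0)
print(group_means.round(6))
print(group_stds.round(6))
```

If the two results agree, the grouped apply is doing exactly what the manual mean/std bookkeeping in the question does, up to the ddof convention.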

EDIT: You can do pretty much anything with SklearnWrapper. Here is an example of transforming and reversing this operation for each group (i.e. without overwriting the transformation object): fit the object anew each time a new group is seen and append it to a list.

I have replicated a bit of sklearn's interface for easier usage (you can extend it with any method you want by passing the appropriate method name as a string to the internal _call_with_function method):

from sklearn.base import clone

class SklearnWrapper:
    def __init__(self, transformation: typing.Callable):
        self.transformation = transformation
        self._group_transforms = []
        # Index of the group handled last; -1 means no group seen yet
        self._pointer = -1

    def _call_with_function(self, df: pd.DataFrame, function: str):
        # Move to the next group, wrapping back to 0 when a new apply starts.
        # This relies on groupby visiting the groups in the same order each time.
        self._pointer = (self._pointer + 1) % len(self._group_transforms)
        return pd.DataFrame(
            getattr(self._group_transforms[self._pointer], function)(df.values),
            columns=df.columns,
            index=df.index,
        )

    def fit(self, df):
        # fit() returns the estimator itself, so clone first; otherwise the
        # list would hold the same object once per group
        self._group_transforms.append(clone(self.transformation).fit(df.values))
        return self

    def transform(self, df):
        return self._call_with_function(df, "transform")

    def fit_transform(self, df):
        self.fit(df)
        return self.transform(df)

    def inverse_transform(self, df):
        return self._call_with_function(df, "inverse_transform")

Usage (group transform, inverse operation and apply it again):

data = load_iris()
feature_cols = data["feature_names"]
df = pd.DataFrame(data["data"], columns=feature_cols)
df["class"] = data["target"]

# Create scaler outside the class
scaler = SklearnWrapper(StandardScaler())

# Fit and transform the feature columns only (holding state);
# group_keys=False keeps the original row index on the result
df_rescaled = df.groupby("class", group_keys=False)[feature_cols].apply(
    scaler.fit_transform
)

# Inverse the operation, grouping by the untouched class labels
df_inverted = df_rescaled.groupby(df["class"], group_keys=False).apply(
    scaler.inverse_transform
)

# Apply transformation once again
df_transformed = df_inverted.groupby(df["class"], group_keys=False).apply(
    scaler.transform
)
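
The stateful wrapper depends on groups being visited in the same order on every apply. As a possible alternative, here is a minimal sketch that keys each fitted transformer by its group label instead. `GroupScaler`, its `transforms_` attribute and the `groups` argument are hypothetical names for illustration, not part of sklearn, and the sketch assumes the DataFrame has a unique index.

```python
import pandas as pd
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler


class GroupScaler:
    """Hypothetical helper: one independently fitted transformer per group label."""

    def __init__(self, transformation):
        self.transformation = transformation
        self.transforms_ = {}  # group label -> fitted clone

    def fit(self, X: pd.DataFrame, groups: pd.Series):
        # clone() gives every group its own unfitted copy of the transformer
        for label, part in X.groupby(groups):
            self.transforms_[label] = clone(self.transformation).fit(part.values)
        return self

    def _apply(self, X: pd.DataFrame, groups: pd.Series, method: str) -> pd.DataFrame:
        # Call the requested method on each group's own transformer
        pieces = [
            pd.DataFrame(
                getattr(self.transforms_[label], method)(part.values),
                columns=part.columns,
                index=part.index,
            )
            for label, part in X.groupby(groups)
        ]
        # Restore the original row order (assumes a unique index)
        return pd.concat(pieces).loc[X.index]

    def transform(self, X, groups):
        return self._apply(X, groups, "transform")

    def inverse_transform(self, X, groups):
        return self._apply(X, groups, "inverse_transform")


data = load_iris()
X = pd.DataFrame(data["data"], columns=data["feature_names"])
y = pd.Series(data["target"], name="class")

scaler = GroupScaler(StandardScaler()).fit(X, y)
X_scaled = scaler.transform(X, y)
X_restored = scaler.inverse_transform(X_scaled, y)
```

Because the lookup is by label, transform and inverse_transform can be called in any order, and on any subset of rows whose labels were seen during fit.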

Szymon Maszke answered Oct 06 '22 00:10