I have a data set like the following and want to scale the data using any of the scalers in `sklearn.preprocessing`.
Is there an easy way to fit the scaler per group rather than over the whole data set? My current solution works, but it can't be included in a `Pipeline`:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
# Toy data: two groups of three rows each, on very different scales.
df = pd.DataFrame(
    {
        'group': [1, 1, 1, 2, 2, 2],
        'x': [1, 2, 3, 10, 20, 30],
    }
)
def scale(x):
    """Standardize one group's values with a freshly fitted StandardScaler.

    Intended as the callback for a groupby ``transform``. A new scaler is
    fitted on ``x`` alone, so each group is centred and scaled with its own
    statistics.

    Returns the scaled values flattened to a 1-D NumPy array.
    """
    # see https://stackoverflow.com/a/72408669/3104974
    # StandardScaler expects 2-D input, so add a trailing axis before fitting.
    as_column = x.values[:, None]
    standardized = StandardScaler().fit_transform(as_column)
    return standardized.ravel()
# Select 'x' explicitly so only that column is scaled. Transforming the whole
# grouped frame returns one column per non-group column, and assigning that to
# the single column 'x_scaled' fails as soon as df has more than one such column
# (e.g. when this line is re-run and 'x_scaled' already exists).
df['x_scaled'] = df.groupby('group')['x'].transform(scale)
group x x_scaled
0 1 1 -1.224745
1 1 2 0.000000
2 1 3 1.224745
3 2 10 -1.224745
4 2 20 0.000000
5 2 30 1.224745
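You can create a custom transformer by subclassing `BaseEstimator` and `TransformerMixin`. Because it implements `fit` and `transform`, it can be used as a step in a `Pipeline`. For example: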
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted
class GroupScaler(BaseEstimator, TransformerMixin):
    """Scale feature columns with one independently fitted scaler per group.

    ``fit`` fits a separate clone of ``scaler`` on the rows of each distinct
    value of ``group_column``. ``transform`` then scales every row with the
    scaler fitted for its group. The class follows the scikit-learn
    transformer API, so it can be used as a ``Pipeline`` step.

    Parameters
    ----------
    group_column : str
        Name of the column in ``X`` whose values define the groups. It is not
        scaled and is passed through to the output.
    scaler : estimator or None, default=None
        Unfitted scikit-learn transformer used as a template for every group.
        It is cloned and never fitted itself. ``None`` means ``StandardScaler()``.

    Attributes
    ----------
    scalers_ : dict
        Maps each group value seen during ``fit`` to its fitted scaler.
    """

    def __init__(self, group_column, scaler=None):
        # scikit-learn convention: store constructor arguments unchanged and
        # create nothing else here. Defaults are resolved in ``fit`` so that
        # ``get_params``/``clone`` round-trip, and fitted attributes (trailing
        # underscore) only exist after ``fit``.
        self.group_column = group_column
        self.scaler = scaler

    def fit(self, X, y=None):
        """Fit one clone of ``scaler`` on each group of ``X``.

        Parameters
        ----------
        X : pandas.DataFrame
            Input containing ``group_column`` plus the feature columns.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        GroupScaler
            The fitted transformer (``self``).
        """
        template = StandardScaler() if self.scaler is None else self.scaler
        self.scalers_ = {}
        for group, group_data in X.groupby(self.group_column):
            group_scaler = clone(template)
            group_scaler.fit(group_data.drop(columns=[self.group_column]))
            self.scalers_[group] = group_scaler
        return self

    def transform(self, X):
        """Scale each row of ``X`` with the scaler fitted for its group.

        Parameters
        ----------
        X : pandas.DataFrame
            Must contain ``group_column`` and the feature columns seen in ``fit``.

        Returns
        -------
        pandas.DataFrame
            Scaled feature columns followed by ``group_column``. The index and
            row order are the same as in ``X``. Rows whose group value is
            excluded by ``groupby`` (missing/NaN) keep NaN features.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If called before ``fit``.
        KeyError
            If ``X`` contains a group value that was not seen during ``fit``.
        """
        check_is_fitted(self)
        features = X.drop(columns=[self.group_column])
        # Fill the result positionally rather than concatenating per-group frames
        # and sorting by index. That keeps output rows aligned 1:1 with input rows
        # (and with y in a Pipeline) even when the index is unsorted or has
        # duplicate labels.
        scaled = np.full((len(X), features.shape[1]), np.nan)
        for group, rows in X.groupby(self.group_column).indices.items():
            try:
                group_scaler = self.scalers_[group]
            except KeyError:
                raise KeyError(
                    f"Group {group!r} in column {self.group_column!r} "
                    "was not seen during fit."
                ) from None
            scaled[rows] = group_scaler.transform(features.iloc[rows])
        result = pd.DataFrame(scaled, index=X.index, columns=features.columns)
        # Group column goes last, matching this transformer's existing output layout.
        result[self.group_column] = X[self.group_column].to_numpy()
        return result
# Demo: fit one StandardScaler per group and scale each group independently.
# (`clone`, used by GroupScaler.fit, is imported with the other imports at the top.)
df = pd.DataFrame({'group': [1, 1, 1, 2, 2, 2], 'x': [1, 2, 3, 10, 20, 30]})
scaler = GroupScaler(group_column='group')
scaled_df = scaler.fit_transform(df)
print(scaled_df)
Output:
x group
0 -1.224745 1
1 0.000000 1
2 1.224745 1
3 -1.224745 2
4 0.000000 2
5 1.224745 2