I am trying to pass a parameter, `DummyTransformer__feature_index_sec`, to my custom scikit-learn transformer via a pipeline. It seems that I need to implement metadata routing in order to do this. However, I have not been able to create a working dummy example:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.utils.metadata_routing import MetadataRouter, MethodMapping
from scipy.sparse import csr_matrix
import pandas as pd
import numpy as np
from sklearn import set_config
# Enable scikit-learn's metadata routing globally (process-wide configuration).
# NOTE(review): the request-setting call further down presumably relies on this
# flag already being set -- confirm against the metadata-routing user guide.
set_config(enable_metadata_routing=True)
class DummyTransformer(BaseEstimator, TransformerMixin):
    """Stateless pass-through transformer that consumes ``feature_index_sec``.

    Nothing is learned during ``fit`` and ``transform`` returns ``X``
    unchanged; the class exists to demonstrate handing an extra per-call
    object (``feature_index_sec``) to a pipeline step.

    ``feature_index_sec`` is an explicit keyword argument of ``fit`` and
    ``transform``, so scikit-learn generates ``set_fit_request`` and
    ``set_transform_request`` for this class and ``BaseEstimator`` supplies
    ``get_metadata_routing``. They are deliberately not re-implemented here:
    hand-written versions that only store a dict on the instance are never
    consulted by the routing machinery.
    """

    def fit(self, X, y=None, feature_index_sec=None, **fit_params):
        """Do nothing and return ``self``; the transformer has no state.

        Parameters
        ----------
        X : array-like or sparse matrix
            Training data (ignored).
        y : object, default=None
            Ignored; present for pipeline compatibility.
        feature_index_sec : object, default=None
            Accepted so that it can be requested as metadata; not used here.
        **fit_params : dict
            Ignored.

        Returns
        -------
        self : DummyTransformer
        """
        return self

    def transform(self, X, feature_index_sec=None):
        """Return ``X`` unchanged after checking ``feature_index_sec``.

        Parameters
        ----------
        X : array-like or sparse matrix
            Data to pass through.
        feature_index_sec : object with a ``shape`` attribute, default=None
            Required companion object; its shape is printed.

        Returns
        -------
        X : the input, untouched.

        Raises
        ------
        ValueError
            If ``feature_index_sec`` is ``None``.
        """
        # Single definition: previously a second ``transform`` further down
        # the class body replaced this one, so the check never executed.
        if feature_index_sec is None:
            raise ValueError("Missing required argument 'feature_index_sec'.")
        print(f"Received feature_index_sec with shape: {feature_index_sec.shape}")
        return X

    def fit_transform(self, X, y=None, feature_index_sec=None):
        """Transform ``X`` directly; ``fit`` is skipped because it is a no-op.

        Raises
        ------
        ValueError
            If ``feature_index_sec`` is ``None`` (raised by ``transform``).
        """
        return self.transform(X, feature_index_sec)

    def set_fit_transform_request(self, **metadata):
        """Request the same metadata for both ``fit`` and ``transform``.

        scikit-learn does not generate a request setter for
        ``fit_transform``; it combines the ``fit`` and ``transform`` requests
        instead. This convenience wrapper therefore forwards ``metadata`` to
        the generated ``set_fit_request`` and ``set_transform_request``.

        Parameters
        ----------
        **metadata : dict
            Request values keyed by metadata name, e.g.
            ``feature_index_sec=True``.

        Returns
        -------
        self : DummyTransformer

        Raises
        ------
        RuntimeError
            Raised by scikit-learn if metadata routing has not been enabled
            via ``set_config(enable_metadata_routing=True)``.
        TypeError
            Raised by scikit-learn for a metadata name that is not a keyword
            argument of ``fit`` / ``transform``.
        """
        self.set_fit_request(**metadata)
        self.set_transform_request(**metadata)
        return self
# Toy inputs: a sparse 10x5 feature matrix and a 10-row frame of side data.
dense_values = np.random.rand(10, 5)
feature_matrix = csr_matrix(dense_values)
closing_prices = np.random.rand(10)
train_idx = pd.DataFrame({'FileDate_ClosingPrice': closing_prices})

# Create the transformer, then register its metadata request on it.
transformer = DummyTransformer()
transformer = transformer.set_fit_transform_request(feature_index_sec=True)

# One-step pipeline wrapping the transformer.
pipeline_steps = [('DummyTransformer', transformer)]
pipe = Pipeline(steps=pipeline_steps)

# Run fit_transform, supplying the extra object under the step-prefixed key.
extra_arguments = {'DummyTransformer__feature_index_sec': train_idx}
pipe.fit_transform(feature_matrix, **extra_arguments)
The example above results in an error: `Pipeline.fit_transform got unexpected argument(s) {'DummyTransformer__feature_index_sec'}, which are not routed to any object.`
An alternative way that does not require metadata routing is to pass the value of `feature_index_sec` using `**params` (see: `Pipeline.fit`). With metadata routing left disabled (the default), a key of the form `<step name>__<parameter>` is forwarded to that step. In scikit-learn, the `fit` method is typically used to store parameters that will be required later during the transformation. You may also pass `feature_index_sec` to `transform` if it is different at that stage.
The adjusted `DummyTransformer` would look as follows:
class DummyTransformer(BaseEstimator, TransformerMixin):
    """Pass-through transformer that needs ``feature_index_sec`` next to ``X``.

    ``fit`` remembers the ``feature_index_sec`` it was given so that a later
    ``transform`` call can reuse it; ``transform`` may also be handed a
    different value explicitly. ``X`` itself is returned unchanged.
    """

    def __init__(self):
        # Populated by fit(); stays None until the transformer has been fitted.
        self.feature_index_sec = None

    def fit(self, X, y=None, feature_index_sec=None, **fit_params):
        """Store ``feature_index_sec`` for later use and return ``self``.

        Parameters
        ----------
        X : array-like or sparse matrix
            Training data (not inspected).
        y : object, default=None
            Ignored; present for pipeline compatibility.
        feature_index_sec : object with a ``shape`` attribute
            Required companion object; its shape is printed and the object
            is kept on the instance.
        **fit_params : dict
            Ignored.

        Returns
        -------
        self : DummyTransformer

        Raises
        ------
        ValueError
            If ``feature_index_sec`` is ``None``.
        """
        # Fail with the same explicit error as transform() rather than an
        # AttributeError from ``None.shape``.
        if feature_index_sec is None:
            raise ValueError("Missing required argument 'feature_index_sec'.")
        print(f"Fit Received feature_index_sec with shape: {feature_index_sec.shape}")
        self.feature_index_sec = feature_index_sec
        return self

    def transform(self, X, feature_index_sec=None, **fit_params):
        """Return ``X`` unchanged after resolving ``feature_index_sec``.

        Parameters
        ----------
        X : array-like or sparse matrix
            Data to pass through.
        feature_index_sec : object with a ``shape`` attribute, default=None
            Value to use for this call. When omitted, the value stored by
            ``fit`` is used instead.
        **fit_params : dict
            Ignored.

        Returns
        -------
        X : the input, untouched.

        Raises
        ------
        ValueError
            If no ``feature_index_sec`` was passed and none was stored by
            ``fit``.
        """
        if feature_index_sec is None:
            # Fall back to whatever fit() remembered.
            feature_index_sec = self.feature_index_sec
        if feature_index_sec is None:
            raise ValueError("Missing required argument 'feature_index_sec'.")
        print(f"Transform Received feature_index_sec with shape: {feature_index_sec.shape}")
        return X

    def fit_transform(self, X, y=None, feature_index_sec=None, **fit_params):
        """Fit, then transform ``X`` with the same ``feature_index_sec``.

        Raises
        ------
        ValueError
            If ``feature_index_sec`` is ``None`` (raised by ``fit``).
        """
        self.fit(X, y, feature_index_sec, **fit_params)
        return self.transform(X, feature_index_sec, **fit_params)
Now, running the example dummy data and pipeline:
# Same toy inputs as in the question: sparse 10x5 features plus a 10-row frame.
random_features = np.random.rand(10, 5)
feature_matrix = csr_matrix(random_features)
random_prices = np.random.rand(10)
train_idx = pd.DataFrame({'FileDate_ClosingPrice': random_prices})

# Plain transformer instance; no request configuration is applied here.
transformer = DummyTransformer()
pipe = Pipeline(steps=[
    ('DummyTransformer', transformer),
])

# Hand the extra object to the step through its "<step>__<param>" key.
step_arguments = {'DummyTransformer__feature_index_sec': train_idx}
pipe.fit_transform(feature_matrix, **step_arguments)
Prints:
Fit Received feature_index_sec with shape: (10, 1)
Transform Received feature_index_sec with shape: (10, 1)
<10x5 sparse matrix of type '<class 'numpy.float64'>'
with 50 stored elements in Compressed Sparse Row format>