I want to use callbacks, eval_set, etc., but I have a problem:
from sklearn.multiclass import OneVsRestClassifier
import lightgbm
verbose = 100
params = {
    "objective": "binary",
    "n_estimators": 500,
    "verbose": 0
}
fit_params = {
    "eval_set": eval_dataset,
    "callbacks": [CustomCallback(verbose)]
}
clf = OneVsRestClassifier(lightgbm.LGBMClassifier(**params))
clf.fit(X_train, y_train, **fit_params)
How can I pass fit_params through to my estimator? I get:
----------------------------------------------------------------------
---> 13 clf.fit(X_train, y_train, **fit_params)
TypeError: OneVsRestClassifier.fit() got an unexpected keyword argument 'eval_set'
Per scikit-learn's docs for OneVsRestClassifier (link), as of v1.4.0 additional **fit_params are only passed through to estimators' fit() methods if you've enabled what scikit-learn calls "metadata routing".
There are 2 required steps which are missing in your example:
1. Enable metadata routing globally with sklearn.set_config(enable_metadata_routing=True).
2. Tell scikit-learn to pass eval_set and callbacks through to the estimator, via .set_fit_request().
Consider this minimal, reproducible example using Python 3.11, lightgbm==4.3.0, and scikit-learn==1.4.1.
import lightgbm as lgb
import sklearn
from sklearn.datasets import make_blobs
from sklearn.multiclass import OneVsRestClassifier
# enable metadata_routing
sklearn.set_config(enable_metadata_routing=True)
# create datasets
X, y = make_blobs(
    n_samples=10_000,
    n_features=10,
    centers=2
)
eval_results = {}
# construct estimator
params = {
    "objective": "binary",
    "n_estimators": 10,
}
fit_params = {
    "eval_set": (X, y),
    "callbacks": [lgb.record_evaluation(eval_results)]
}
clf = OneVsRestClassifier(
    lgb.LGBMClassifier(**params)
    .set_fit_request(callbacks=True, eval_set=True)
)
# train
clf.fit(X, y, **fit_params)
# check eval results, to prove that the callback was used
print(eval_results)
# {'valid_0': OrderedDict([('binary_logloss', [0.598138869381609, 0.5203293282602738, 0.45544446427154844, 0.40059849184355334, 0.3537472248673818, 0.31338812592304066, 0.2783839141567028, 0.24785302530927006, 0.22109850424011224, 0.19756016345789282])])}