As the title says, I'm trying to increment a variable after each training iteration of the CatBoost classifier so I can update a progress bar in a GUI, and I can't find anything about it online.
So far I've tried making a class with an after_iteration method and passing that class, or an instance of it, as a parameter to the fit method:
class Test_Callback():
    def after_iteration(self, info):
        global progress
        progress += 1

test = Test_Callback()
model.fit(train_data, eval_set=test_data, callbacks=[test])
# Alternatively, create the instance inline
model.fit(train_data, eval_set=test_data, callbacks=[Test_Callback()])
This does increment the variable, but it also ends training after a single iteration.
I've also tried passing an after_iteration() function directly to the callbacks parameter, but that just crashes the script.
I can't test it, but I found example callbacks in CatBoost's repo, MetricsCheckerCallback and EarlyStopCallback, and it seems CatBoost uses the returned value (True/False) to decide whether to continue or stop iterating.
Because you don't use return, your method automatically returns None (which is treated as False), and this can stop the iterations.
You may need to add return True:
class Test_Callback():
    def after_iteration(self, info):
        global progress
        progress += 1
        return True
BTW:
I would put progress inside the class:
class Test_Callback():
    def __init__(self):
        self.progress = 0

    def after_iteration(self, info):
        self.progress += 1
        return True
but then you need to create the instance before fit() so you still have access to self.progress after fit() finishes:
test = Test_Callback()
model.fit(train_data, eval_set=test_data, callbacks=[test])
print(test.progress)
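For a GUI you probably want the count while training runs, not only after fit() returns. A minimal sketch: the callback forwards each new count to a function you pass in, such as a progress bar's update method. The on_progress parameter is a hypothetical name, and SimpleNamespace only fakes the info object CatBoost would pass.

```python
from types import SimpleNamespace


class ProgressCallback:
    """Counts finished iterations and forwards the count to on_progress."""

    def __init__(self, on_progress=None):
        self.progress = 0
        # on_progress: hypothetical hook, any callable taking the current
        # count (e.g. a function that updates a GUI progress bar)
        self.on_progress = on_progress

    def after_iteration(self, info):
        self.progress += 1
        if self.on_progress is not None:
            self.on_progress(self.progress)
        return True  # keep training


# A plain list stands in for the GUI widget here
updates = []
cb = ProgressCallback(on_progress=updates.append)
for i in range(3):
    cb.after_iteration(SimpleNamespace(iteration=i + 1))  # fake info
print(cb.progress, updates)  # 3 [1, 2, 3]
```

With a real model you would pass it the same way as above, e.g. model.fit(train_data, eval_set=test_data, callbacks=[cb]). If fit() runs on the same thread as the GUI event loop, the window typically won't repaint until training ends; running fit() in a worker thread is a common workaround, but toolkits such as Tkinter and Qt generally expect widget updates on the main thread, so the hook may need to hand the value over (for example through a queue).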
The EarlyStopCallback example also suggests that you can use info.iteration for this (but I didn't test it):
class Test_Callback():
    def __init__(self):
        self.progress = 0

    def after_iteration(self, info):
        self.progress = info.iteration
        return True
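If the progress bar wants a percentage, info.iteration can be divided by the total number of iterations. This sketch assumes you pass that total yourself (the iterations value you set on the model) and that info.iteration counts finished iterations starting at 1, which is what the len(...) == info.iteration assertion in the MetricsCheckerCallback example quoted below suggests. Early stopping can end training before 100%.

```python
from types import SimpleNamespace


class PercentCallback:
    def __init__(self, total_iterations):
        # total_iterations: assumed to match the model's `iterations` setting
        self.total = total_iterations
        self.percent = 0.0

    def after_iteration(self, info):
        # Assumption: info.iteration = number of finished iterations (1-based)
        self.percent = min(100.0, 100.0 * info.iteration / self.total)
        return True


cb = PercentCallback(total_iterations=200)
cb.after_iteration(SimpleNamespace(iteration=50))  # fake info object
print(cb.percent)  # 25.0
```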
Example callbacks from the CatBoost repo, so nobody has to visit the links:
class MetricsCheckerCallback:
    def after_iteration(self, info):
        # metric_names is defined elsewhere in the original example (not shown here)
        for dataset_name in ['learn', 'validation_0', 'validation_1']:
            assert dataset_name in info.metrics
            for metric_name in metric_names:
                assert metric_name in info.metrics[dataset_name]
                assert len(info.metrics[dataset_name][metric_name]) == info.iteration
        return True

model.fit(train_data, train_labels,
          callbacks=[MetricsCheckerCallback()],
          eval_set=[validation_0, validation_1])
and
class EarlyStopCallback:
    def __init__(self, stop_iteration):
        self._stop_iteration = stop_iteration

    def after_iteration(self, info):
        return info.iteration != self._stop_iteration

model.fit(train_data, train_labels, callbacks=[
    EarlyStopCallback(7),
    EarlyStopCallback(5),
    EarlyStopCallback(6)
])
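The EarlyStopCallback example passes three callbacks with different stop iterations; if training stops when any one of them returns False, that run would end at iteration 5. Under that assumption, a progress counter can sit in the same list as a stopping callback. run_callbacks is a hypothetical loop for illustration, not CatBoost's implementation, and whether CatBoost still calls the remaining callbacks on the stopping iteration is not confirmed here.

```python
from types import SimpleNamespace


class EarlyStopCallback:
    def __init__(self, stop_iteration):
        self._stop_iteration = stop_iteration

    def after_iteration(self, info):
        return info.iteration != self._stop_iteration


class CounterCallback:
    def __init__(self):
        self.progress = 0

    def after_iteration(self, info):
        self.progress += 1
        return True


def run_callbacks(callbacks, iterations):
    """Hypothetical training loop: stop when any callback returns a falsy value."""
    for i in range(1, iterations + 1):
        info = SimpleNamespace(iteration=i)  # fake info object
        # In this sketch every callback is called, then the results are checked
        results = [cb.after_iteration(info) for cb in callbacks]
        if not all(results):
            return i
    return iterations


counter = CounterCallback()
stopped_at = run_callbacks(
    [EarlyStopCallback(7), EarlyStopCallback(5), EarlyStopCallback(6), counter],
    iterations=10,
)
print(stopped_at, counter.progress)  # 5 5
```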