How to make XGBoost external memory and XGBoost survival AFT model work together?
Background: I've written an XGBoost iterator for batched training, as in the linked example.
Now I want to train an AFT model from the xgboost library.
The problem is the XGBoost DMatrix, on which we need to call set_float_info to set the survival censoring intervals. For example:
dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index])
dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index])
Attached please find my redacted code (can't attach everything, but that's the problematic gist).
I have the censoring time data in df, but I don't know how to "attach" it to Xy_train.
class BatchedParquetIterator(xgboost.DataIter):
    def __init__(
        self
    ):
        # ...
        super().__init__(cache_prefix=os.path.join(".", "cache"))

    def next(self, input_data: Callable):
        """Advance the iterator by 1 step and pass the data to XGBoost. This function is
        called by XGBoost during the construction of ``DMatrix``
        """
        if self._it == len(self._file_paths):
            return 0  # return 0 to let XGBoost know this is the end of iteration
        df = pd.read_parquet(self._file_paths[self._it])
        X, y = self._preprocess(df)
        input_data(data=X, label=y)
        self._it += 1
        return 1  # Return 1 to let XGBoost know we haven't seen all the files yet.

    def reset(self):
        """Reset the iterator to its beginning"""
        self._it = 0

    def _preprocess(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        # ...
        return X, y

parquet_iterator_train = BatchedParquetIterator(batches)
Xy_train = xgboost.DMatrix(parquet_iterator_train)
Turns out it's easy. The documentation states that:
input_data is a function passed in by XGBoost who has the exact same signature of DMatrix.
Interestingly, the lower and upper bounds can be passed not only through set_float_info (as in the AFT tutorial), but also through the DMatrix constructor (see the documentation).
All in all, one simply needs to change one line in the aforementioned code:
class BatchedParquetIterator(xgboost.DataIter):
    # ...
    def next(self, input_data: Callable):
        # ...
        input_data(data=X, label=y, label_lower_bound=llb, label_upper_bound=lub)
        # ...
where llb and lub are the arrays defining the considered intervals.
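In case it helps, here is a minimal sketch of how llb and lub could be built for each batch. It assumes right-censored data with one duration and one event indicator per row; the helper name censoring_bounds and its time/event arguments are hypothetical and not part of the question's code. It follows the convention from the AFT tutorial: uncensored rows get equal lower and upper bounds, and right-censored rows get an upper bound of +inf.

```python
import numpy as np


def censoring_bounds(time, event):
    """Build AFT interval bounds for right-censored data (hypothetical helper).

    time  : observed or censoring time for each row
    event : 1 if the event was observed, 0 if the row is right-censored
    """
    time = np.asarray(time, dtype=np.float32)
    observed = np.asarray(event).astype(bool)
    # The lower bound is always the recorded time.
    llb = time.copy()
    # Uncensored rows: upper == lower. Right-censored rows: upper = +inf.
    lub = np.where(observed, time, np.inf).astype(np.float32)
    return llb, lub


# Three rows; the second one is right-censored at t = 5.0.
llb, lub = censoring_bounds([2.0, 5.0, 3.5], [1, 0, 1])
```

Something like this could be called from _preprocess (or directly in next) on the columns of df that hold the censoring data, so that every batch hands its own llb and lub to input_data.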