Tags: python, xgboost, survival-analysis

XGBoost AFT survival model with external memory iterator


How to make XGBoost external memory and XGBoost survival AFT model work together?

Background: I've written an XGBoost iterator for batched training, following the linked example. Now I want to train an AFT model from the xgboost library. The problem is the XGBoost DMatrix: for AFT we need to call set_float_info on it to set the survival censoring intervals. For example:

dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index])
dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index])
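
For context, these bounds encode interval censoring: an observed event maps to a degenerate interval [t, t], while a right-censored record maps to [t, +inf). A minimal sketch, with hypothetical `time`/`event` arrays standing in for the real censoring data:

```python
import numpy as np

# Hypothetical censoring data: observed time and an event indicator
# (1 = event observed, 0 = right-censored).
time = np.array([2.0, 5.0, 3.5, 7.0])
event = np.array([1, 0, 1, 0])

# Observed event -> degenerate interval [t, t];
# right-censored  -> [t, +inf).
y_lower_bound = time
y_upper_bound = np.where(event == 1, time, np.inf)
```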

Below is my redacted code (I can't attach everything, but this is the problematic gist). I have the censoring-time data in df, but I don't know how to "attach" it to Xy_train.

import os
from typing import Callable, Tuple

import pandas as pd
import xgboost


class BatchedParquetIterator(xgboost.DataIter):
  def __init__(self):
    # ...
    super().__init__(cache_prefix=os.path.join(".", "cache"))

  def next(self, input_data: Callable):
    """Advance the iterator by 1 step and pass the data to XGBoost.  This function is
    called by XGBoost during the construction of ``DMatrix``
    """
    
    if self._it == len(self._file_paths):
      return 0  # return 0 to let XGBoost know this is the end of iteration
    
    df = pd.read_parquet(self._file_paths[self._it])
    X, y = self._preprocess(df)
    
    input_data(data=X, label=y)
    self._it += 1

    return 1  # Return 1 to let XGBoost know we haven't seen all the files yet.

  def reset(self):
    """Reset the iterator to its beginning"""
    self._it = 0

  def _preprocess(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    # ...
    return X, y


parquet_iterator_train = BatchedParquetIterator(batches)
Xy_train = xgboost.DMatrix(parquet_iterator_train)

Solution

  • Turns out it's easy. The documentation states that:

    input_data is a function passed in by XGBoost who has the exact same signature of DMatrix.

    Interestingly, lower and upper bounds can be passed not only through set_float_info (as in the AFT tutorial), but also through the DMatrix constructor (see the documentation).

    All in all, one simply needs to change one line in the aforementioned code:

    class BatchedParquetIterator(xgboost.DataIter):
      # ...
      def next(self, input_data: Callable):
        # ...
        input_data(data=X, label=y, label_lower_bound=llb, label_upper_bound=lub)
        # ...
    

    where llb and lub are the arrays holding the lower and upper bounds of the censoring intervals (with np.inf as the upper bound for right-censored records).