Tags: python, multiprocessing, python-multiprocessing, xgboost

Process hangs when multiprocessing with XGBoost model batch prediction


Here's a batch prediction case using multiprocessing. Steps:

  1. Inside with mp.Pool(processes=num_processes) as pool, there is a with Dataset(dataset_code) as data block in the main process that fetches data over a websocket, and it works well.

  2. The work is then dispatched to the pool with

result = pool.apply_async(pred_data_from_db, args=(start_index, chunk))

  3. Inside pred_data_from_db, users can plug in their own predict Python file via

predict = getattr(module, customized_pred_func_name)

  4. The issue comes when an XGBoost pkl model reaches the predict step: the process simply hangs.

Other Information:

  1. The model has been loaded before multiprocessing starts.

  2. There is a Flask interface so other services can query the progress of the batch prediction.

  3. This issue does not happen when predicting with an SVM pkl model.

  4. I have already set n_jobs to 1.

I hope to find ways to locate the problem.


Solution

  • Since a sample of the input data and the exact error message were not given, I can only point at the likely causes. Don't copy and paste the code as-is; it is guidance toward troubleshooting, not a complete solution, so use the techniques described below. To see where a worker is actually stuck, start with the traceback dump shown right after this list.

    The likely causes are:

    1. XGBoost's internal parallelism conflicting with multiprocessing. Even with n_jobs=1, XGBoost may still initialize OpenMP, and OpenMP thread state does not survive a fork(), so a forked child can deadlock the first time it calls predict.

    2. Global state contamination: child processes in an mp.Pool inherit shared global variables (such as an already-loaded model) in unpredictable ways, which can lead to crashes or hangs.
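
    To locate where a hung worker is stuck, here is a minimal sketch using the standard-library faulthandler module (the 60-second timeout is an arbitrary choice; module, customized_pred_func_name, and model are the names from your code):

    import faulthandler
    
    def pred_data_from_db(start_index, chunk):
        # If this worker is still running after 60 s, dump every thread's
        # traceback to stderr so you can see the exact line it is stuck on
        faulthandler.dump_traceback_later(60, exit=False)
        try:
            predict = getattr(module, customized_pred_func_name)
            return predict(model, chunk)
        finally:
            faulthandler.cancel_dump_traceback_later()
    

    If the dump only shows the worker sitting inside predict, the hang is likely in native OpenMP code; running py-spy dump --pid <worker_pid> from another shell gives a fuller picture.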

    Verify model loading in child processes (a fresh model per worker avoids sharing the parent's non-fork-safe state):

    import xgboost as xgb
    
    def pred_data_from_db(start_index, chunk):
        # Load the model inside the worker (ensures fresh state)
        model = xgb.Booster()
        model.load_model('model.xgb')  # or your pkl loading method
        
        predict = getattr(module, customized_pred_func_name)
        return predict(model, chunk)

    Disable XGBoost's internal threading and run single-threaded:

    import os
    
    # Set OMP_NUM_THREADS and OPENBLAS_NUM_THREADS to 1 before importing xgboost
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['OPENBLAS_NUM_THREADS'] = '1'
    
    import xgboost as xgb
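
    Note that with the spawn start method (recommended below), each worker re-imports your module, so these environment variables must be set at module top level, before the xgboost import, for every child to pick them up.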
    

    Use spawn instead of fork. On Linux the default start method is fork, and forking a process that has already initialized OpenMP (as XGBoost does) is a classic cause of hangs; spawn starts each worker with a clean interpreter.

    import multiprocessing as mp
    
    if __name__ == '__main__':
        mp.set_start_method('spawn')  # must be called before the Pool is created
        with mp.Pool(processes=num_processes) as pool:
            result = pool.apply_async(pred_data_from_db, args=(start_index, chunk))
           
    

    Wrap the prediction in a try/except and log detailed errors:

    def pred_data_from_db(start_index, chunk):
        try:
            predict = getattr(module, customized_pred_func_name)
            return predict(model, chunk)
        except Exception:
            import traceback
            with open('error.log', 'a') as f:
                f.write(f"Failed on chunk {start_index}:\n")
                f.write(traceback.format_exc())
            raise
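
    Also note that pool.apply_async does not surface worker exceptions in the parent until you call .get() on the returned AsyncResult; if .get() is never called, a crashed worker can look exactly like a hang.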
            
    

    Test with dummy data to find out whether the issue is data-dependent:

    # Replace real data with simple test data
    dummy_chunk = [[1, 2, 3]] * 100
    
    result = pool.apply_async(pred_data_from_db, args=(0, dummy_chunk))
    print(result.get(timeout=60))  # .get() re-raises any worker exception
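
    If the dummy chunk predicts fine but the real data does not, compare chunk sizes and dtypes: each chunk is pickled in full when sent to a worker, and very large chunks can make that step slow enough to look like a hang.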
    

    If you're using Flask with multiprocessing, ensure no Flask application or request context is passed to the workers. Also use mp.Manager() for shared progress tracking:

    manager = mp.Manager()
    progress = manager.dict()
    
    def worker(progress, start_index, chunk):
        # Pass the manager proxy explicitly; it is picklable and spawn-safe
        progress[start_index] = 'running'
    
    pool.apply_async(worker, args=(progress, start_index, chunk))
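
    A minimal sketch of how the Flask side could read that shared dict (the /progress route name is just an example; progress is the manager.dict() created above):

    from flask import Flask, jsonify
    
    app = Flask(__name__)
    
    @app.route('/progress')
    def get_progress():
        # dict() copies the proxy's current contents into a plain dict
        return jsonify(dict(progress))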