Here's a batch prediction case using multiprocessing. Steps:

After with mp.Pool(processes=num_processes) as pool:, there's a with Dataset(dataset_code) as data: in the main process that uses a websocket to get the data, and it works well. Then it goes to the multiprocessing task with:

result = pool.apply_async(pred_data_from_db, args=(start_index, chunk))
predict = getattr(module, customized_pred_func_name)
Other information:
the model is loaded before multiprocessing starts
there is a Flask interface that lets other services query the progress of the batch prediction
the issue does not happen when predicting with an SVM pkl model
I have already set n_jobs to 1
I hope to find a way to locate the problem.
Since a sample of the input data and the exact error message were not given, I can only point out possible causes. Don't copy and paste the code as-is; it is guidance toward troubleshooting, not a complete solution.
The issues could be:
1. XGBoost's internal parallelism conflicting with multiprocessing. Even with n_jobs=1, XGBoost may still use OpenMP threads.
2. Global state contamination: child processes in an mp.Pool inherit shared global variables in unpredictable ways, which can lead to crashes or hangs.
Verify model loading in child processes:

import xgboost as xgb

def pred_data_from_db(start_index, chunk):
    # Load the model inside the worker (ensures fresh state per process)
    model = xgb.Booster()
    model.load_model('model.xgb')  # or your pkl loading method
    predict = getattr(module, customized_pred_func_name)
    return predict(model, chunk)
Disable XGBoost's internal threading and use single-thread mode:

import os
# Set OMP_NUM_THREADS and OPENBLAS_NUM_THREADS to 1 before importing xgboost
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
import xgboost as xgb
Use spawn instead of fork (in case you are currently using fork). fork is only available on Unix, and forked children inherit the parent's memory, including OpenMP thread state, which can cause hangs; spawn starts each worker with a clean interpreter.
if __name__ == '__main__':
    mp.set_start_method('spawn')  # add before Pool creation
    with mp.Pool() as pool:
        ...
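Note that with spawn, workers do not inherit the model you loaded in the main process, so each worker has to set it up itself. A common pattern is a pool initializer that runs once per worker; this is a minimal, self-contained sketch (the string "model" and the summing worker are stand-ins for your actual XGBoost loading and prediction):

```python
import multiprocessing as mp

_model = None  # set once per worker process by the initializer


def init_worker(model_path):
    # In real code: load the XGBoost model from model_path here.
    # Stand-in: just record the path so each worker has its own "model".
    global _model
    _model = f"model loaded from {model_path}"


def predict_chunk(chunk):
    # Stand-in for model.predict(chunk)
    return (_model, sum(chunk))


def run_pool():
    with mp.Pool(processes=2, initializer=init_worker,
                 initargs=('model.xgb',)) as pool:
        results = [pool.apply_async(predict_chunk, args=([i, i + 1],))
                   for i in range(3)]
        return [r.get(timeout=30) for r in results]


if __name__ == '__main__':
    print(run_pool())
```

This also sidesteps the global-state contamination mentioned above, since each worker builds its state from scratch instead of inheriting it.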
Add detailed error logging by wrapping the prediction in try/except:

def pred_data_from_db(start_index, chunk):
    try:
        predict = getattr(module, customized_pred_func_name)
        return predict(model, chunk)
    except Exception:
        import traceback
        with open('error.log', 'a') as f:
            f.write(f"Failed on chunk {start_index}:\n")
            f.write(traceback.format_exc())
        raise
Test with dummy data to find out whether the issue is data-dependent:

# Replace the real data with simple test data
dummy_chunk = [[1, 2, 3]] * 100
pool.apply_async(pred_data_from_db, args=(0, dummy_chunk))
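One thing that bites people with apply_async: a worker exception is only raised in the parent when you call .get() on the AsyncResult, so a failing worker can look like silent success (or a hang, if you wait on something else). A minimal, self-contained demonstration (failing_worker is a stand-in for pred_data_from_db):

```python
import multiprocessing as mp


def failing_worker(start_index, chunk):
    raise ValueError(f"failed on chunk {start_index}")


def run_demo():
    with mp.Pool(processes=2) as pool:
        result = pool.apply_async(failing_worker, args=(0, [1, 2, 3]))
        try:
            result.get(timeout=30)  # the worker's exception only surfaces here
        except ValueError as e:
            return str(e)
    return None


if __name__ == '__main__':
    print(run_demo())  # -> failed on chunk 0
```

So when locating the problem, make sure every AsyncResult is collected with .get() (ideally with a timeout), otherwise the error.log technique above is your only visibility into worker failures.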
If you're using Flask with multiprocessing, ensure no Flask application or request context is passed to the workers. Also use mp.Manager() for shared progress tracking:

manager = mp.Manager()
progress = manager.dict()

def worker():
    progress['status'] = 'running'
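Since a Manager dict is a picklable proxy, you can pass it to the workers explicitly instead of relying on a module-level global (which spawn workers will not inherit). A self-contained sketch, assuming the Flask progress endpoint reads the dict in the main process; the chunk contents and 'done' markers are stand-ins:

```python
import multiprocessing as mp


def worker(progress, start_index, chunk):
    # ... run the actual prediction on chunk here ...
    progress[start_index] = 'done'  # visible to the main process (and Flask)


def run_batch():
    manager = mp.Manager()
    progress = manager.dict()
    chunks = {0: [1, 2, 3], 100: [4, 5, 6]}
    with mp.Pool(processes=2) as pool:
        results = [pool.apply_async(worker, args=(progress, i, c))
                   for i, c in chunks.items()]
        for r in results:
            r.get(timeout=30)  # also surfaces any worker exception
    return dict(progress)


if __name__ == '__main__':
    print(run_batch())  # {0: 'done', 100: 'done'} (key order may vary)
```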