
Python Multiprocessing: when I launch many processes on a huge pandas data frame, the program gets stuck


I am trying to reduce execution time with Python's multiprocessing library (Pool.starmap) on code that executes the same task in parallel, on the same pandas DataFrame, but with different call arguments.

When I run my code on a small portion of the DataFrame with 10 jobs, everything works fine. However, when I run it on the whole 100,000,000-row dataset with 63 jobs (on a cluster node with 64 CPU cores), the code just... freezes. It is still running, but it does not seem to be doing anything: I know this because the code is supposed to print that it is alive roughly once every 10,000 tasks, and nothing appears.
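For the liveness signal described above, one way to avoid printing from the workers at all is to let the parent process report progress as results come back. This sketch assumes Pool.imap_unordered is acceptable (results arrive in completion order, so each result carries its own key); slow_task, run_with_progress, report_every and the chunk size are illustrative choices, not values from the question.

```python
import multiprocessing as mp


def slow_task(x):
    # Stand-in for the real per-store computation; returns (key, value)
    return x, x * x


def run_with_progress(items, processes=4, report_every=10_000, chunksize=1_000):
    results = {}
    with mp.Pool(processes=processes) as pool:
        # imap_unordered yields results as chunks finish, so the parent can
        # count completed tasks while the workers keep going
        for done, (key, value) in enumerate(
            pool.imap_unordered(slow_task, items, chunksize=chunksize), start=1
        ):
            results[key] = value
            if done % report_every == 0:
                # Only the parent writes to stdout; flush so the line shows up
                # even when stdout is redirected to a file
                print(f"{done}/{len(items)} tasks done", flush=True)
    return results


if __name__ == "__main__":
    out = run_with_progress(list(range(50_000)))
    print(len(out))  # 50000
```

A larger chunksize means fewer round trips between the parent and the workers, at the cost of coarser progress updates.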

I have searched and found similar issues on the internet, but none of the answers applied to my specific case, so here I am.

Minimal Example

I have made a minimal, self-contained example that reproduces this problem. To simplify, let's say my DataFrame has 2 columns: the first is "stores", the second is "price". I want to recover the mean price for each store. Of course, in this specific problem one would just group the DataFrame by store and aggregate the price, but this is a simplification; let's assume the task can only be done one store at a time (this is my case). Here is what a minimal example looks like:

#minimal example
#changes according to SIGHUP and Frank Yellin

import time
import pandas as pd
import random as rd
import multiprocessing as mp

import psutil #RAM usage

def create_datafile(nrows):
    """
    create a random pandas dataframe file
    To visualize this rather simple example,
    let's say that we are looking at a pool of 0.1*nrows products across different stores,
    that can have different values of the attribute "price"
    (in the list "stores").
    """
    
    price = [rd.randint(0,300) for i in range(nrows)]
    stores = [i%(0.1*nrows) for i in range(nrows)]

    data=zip(stores,price)
 
    return pd.DataFrame(data=data, columns=["stores", "price"])

            
def task(store):
    """
    the task we want to accomplish: compute mean price
    for each store in the dataframe.
    """
    global data

    if rd.randint(1,10000)==1:
        print('I am alive!')

    product_df = data.groupby('stores', as_index = False).agg(mean_price = ("price", 'mean'))

    selected_store = product_df[product_df['stores'] == store] #select the mean for a given store

    return (store, selected_store['mean_price'])

def pinit(_df):
    global data 
    data = _df

def get_ram_usage_pct(): 
    #source: https://www.pragmaticlinux.com/2020/12/monitor-cpu-and-ram-usage-in-python-with-psutil/
    """
    Obtains the system's current RAM usage.
    :returns: System RAM usage as a percentage.
    :rtype: float
    """
    return psutil.virtual_memory().percent


if __name__ == "__main__":
    ##
    nrows=100000000
    nb_jobs= 63

    print('Creating data...')
    df = create_datafile(nrows)
    print('Data created.')

    print('RAM usage after data creation is {} %'.format(get_ram_usage_pct()))


    stores_list = [i%(0.1*nrows) for i in range(nrows)]

    dic_mean={}
    #launch multiprocessing tasks with map_async
    tic=time.time()
    print(f'Max number of jobs: {mp.cpu_count() - 1}')
    print(f'Running: {min(nb_jobs, mp.cpu_count() - 1)} jobs...')
    with mp.Pool(initializer=pinit, initargs=(df,), processes=min(nb_jobs, mp.cpu_count() - 1)) as pool:
        for store,result in pool.map_async(task, stores_list).get():
            dic_mean[store] = result.iloc[0] #result is a one-element Series; keep its value
    toc=time.time()
    print(f'Processed data in {round((toc-tic)/60,1)} minutes (rounded to 0.1).') 

    #print(dic_mean)
    #dic_mean now contains the mean price of every store.
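A rough count of the work this example schedules may help explain why the full-size run looks stuck even apart from the printing problem. Assuming one "row visit" per row that groupby reads (a simplification that ignores constant factors), stores_list holds one entry per row and every call to task groups all nrows rows of the DataFrame:

```latex
\underbrace{10^{8}}_{\text{tasks in stores\_list}} \times \underbrace{10^{8}}_{\text{rows grouped per task}} = 10^{16}\ \text{row visits},
\qquad \frac{10^{16}}{63} \approx 1.6 \times 10^{14}\ \text{row visits per worker}
```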

I am using Python 3.9.2.

If you launch this code with:

I am rather new to Python multiprocessing, so any hint would be welcome! Thanks in advance.
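For checking results on small inputs, the plain groupby mentioned in the question gives the reference answer. If the work really must happen one store at a time, one possible cheaper per-store step is to compute each store's row positions once with DataFrame.groupby(...).indices and have each task read only its own rows, instead of regrouping the whole frame on every call as the example's task does. A minimal sketch; the helper names reference_means, build_store_index and mean_for_store are hypothetical.

```python
import pandas as pd


def reference_means(df):
    # Vectorized answer: a single groupby over the whole frame
    return df.groupby("stores")["price"].mean().to_dict()


def build_store_index(df):
    # Map each store to the integer positions of its rows (computed once, up front)
    return df.groupby("stores").indices


def mean_for_store(prices, store_index, store):
    # Per-store task: reads only the rows that belong to `store`
    return float(prices[store_index[store]].mean())


if __name__ == "__main__":
    df = pd.DataFrame({"stores": [0, 1, 0, 2, 1], "price": [10, 20, 30, 40, 60]})
    prices = df["price"].to_numpy()  # plain NumPy array, cheap to index by position
    index = build_store_index(df)
    per_store = {int(s): mean_for_store(prices, index, s) for s in sorted(index)}
    print(per_store)  # {0: 20.0, 1: 40.0, 2: 40.0}
    assert per_store == reference_means(df)
```

In a pool, the price array and the store index could be built once in the initializer, so that each task costs time proportional to its own store's rows rather than to the whole frame.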


Solution

  • OK, thanks to @SIGHUP and @Frank Yellin I was able to find the issue, so I am sharing it here in case anyone runs into something similar.

    Python seems unable to print anything from the worker processes when too many of them are running concurrently.

    To check whether your program is still alive, the solution is to make it write to a file instead (a .txt file, for example). Once there are too many processes, print statements will NOT appear in the Python console.

    I don't know whether the print statements freeze the entire program or whether it keeps running. However, I would suggest removing all print statements to avoid surprises.

    Here is a way to make the code from my example work without hassle (beware: 1,000,000 rows or more will take a long time):

    #minimal example
    #changes according to SIGHUP and Frank Yellin
    
    import time
    import pandas as pd
    import random as rd
    import multiprocessing as mp
    
    import psutil #RAM usage
    
    import sys
    
    def create_datafile(nrows):
        """
        create a random pandas dataframe file
        To visualize this rather simple example,
        let's say that we are looking at a pool of 0.1*nrows products across different stores,
        that can have different values of the attribute "price"
        (in the list "stores").
        """
        
        price = [rd.randint(0,300) for i in range(nrows)]
        stores = [i%(0.1*nrows) for i in range(nrows)]
    
        data=zip(stores,price)
     
        return pd.DataFrame(data=data, columns=["stores", "price"])
    
                
    def task(store):
        """
        the task we want to accomplish: compute mean price
        for each store in the dataframe.
        """
        global data
        global alive_file
    
        #print('I am alive!',flush=True) DO NOT put a print statement
    
        if rd.randint(1, 10000) == 1: #write a heartbeat roughly once every 10,000 tasks
            with open(alive_file, 'a') as f:
                f.write("I am alive !\n")
    
        product_df = data.groupby('stores', as_index = False).agg(mean_price = ("price", 'mean'))
    
        selected_store = product_df[product_df['stores'] == store] #select the mean for a given store
    
        return (store, selected_store['mean_price'])
    
    def pinit(_df, _alive_file):
        global data 
        global alive_file
        data = _df
        alive_file = _alive_file
    
    def get_ram_usage_pct(): #source: https://www.pragmaticlinux.com/2020/12/monitor-cpu-and-ram-usage-in-python-with-psutil/
        """
        Obtains the system's current RAM usage.
        :returns: System RAM usage as a percentage.
        :rtype: float
        """
        return psutil.virtual_memory().percent
    
    
    if __name__ == "__main__":
        ##
        nrows= int(sys.argv[1]) #number of rows in dataframe
        nb_jobs= int(sys.argv[2]) #number of jobs
    
    
        print('Creating data...')
        tic=time.time()
        df = create_datafile(nrows)
        toc=time.time()
        print(f'Data created. Took {round((toc-tic)/60,1)} minutes (rounded to 0.1)')
    
        print('RAM usage after data creation is {} %'.format(get_ram_usage_pct()))
    
        #print(df)
    
        #create parameters for multiprocessing task: one 1-tuple per call, as starmap expects
        stores_list = [(i % (0.1 * nrows),) for i in range(nrows)]
    
        dic_mean={}
        #launch multiprocessing tasks with starmap_async
        tic=time.time()
        print(f'Max number of jobs: {mp.cpu_count() - 1}')
        print(f'Running: {min(nb_jobs, mp.cpu_count() - 1)} jobs...')
        with mp.Pool(initializer=pinit, initargs=(df,"alive.txt",), processes=min(nb_jobs, mp.cpu_count() - 1)) as pool:
            for store,result in pool.starmap_async(task, stores_list).get():
                dic_mean[store] = result.iloc[0] #result is a one-element Series; keep its value
        toc=time.time()
        print(f'Processed data in {round((toc-tic)/60,1)} minutes (rounded to 0.1).') 
    
        #print(dic_mean)
        #dic_mean now contains the mean price of every store.
    

    Thank you to everyone who took the time to examine my issue and made my code better, helping me identify the true issue.
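A variant of the file-based liveness check from this answer, sketched under assumptions: each worker keeps its own task counter, writes only every N-th task, ends each record with a newline, and tags it with its PID so that stalled workers are easy to spot. It assumes a local POSIX-like file system where short appends from several processes do not interleave mid-line; HEARTBEAT_EVERY, heartbeat and task_with_heartbeat are hypothetical names.

```python
import os
import time

HEARTBEAT_EVERY = 10_000  # hypothetical: one line per 10,000 tasks per worker
_tasks_done = 0  # per-process counter; each worker process has its own copy


def heartbeat(path, counter):
    # Append one short, newline-terminated line. Opening and closing the file
    # each time flushes Python's buffer, so the line is handed to the OS at once.
    if counter % HEARTBEAT_EVERY == 0:
        with open(path, "a") as f:
            f.write(f"{time.strftime('%H:%M:%S')} pid={os.getpid()} tasks={counter}\n")


def task_with_heartbeat(store, path="alive.txt"):
    global _tasks_done
    _tasks_done += 1
    heartbeat(path, _tasks_done)
    return store  # the real per-store computation would go here
```

While the pool runs, tail -f alive.txt in a second shell shows which worker PIDs are still making progress.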