
Run multiprocessing inside class method


I'm trying to run multiprocessing inside a class method.

I don't understand what's going on. I first got an error when passing a class attribute / class method to the pool's map function, so I moved the function outside the class, but I still get the error.

My job function is:

def run_a_request(client,  prompt, stream=False):
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": prompt
            }
        ],
        temperature=0.1,
        stream=stream
    )
    if not stream:
        response = response.choices[0].message.content

    return response

The calling function:

    def run_intermediate_response(self, subgraph, components, question):

        intermediate_response = ""
        prompts = []
        for idx in range(len(components)):
            tmp = dict()
            tmp['components'] = [subgraph['components'][idx]]

            knowledge = self.__graph_linearizor.run(tmp)[0][0]
            prompts.append(generate_prompt_for_global_search(knowledge, question))

        pool = Pool(5)
        result = pool.map(run_a_request, zip(repeat(client), prompts))

        print(result)

        return 1

I get this error:

  File "/home/ju/PycharmProjects/kgqa_graphrag/framework/components/graph_retriever/base/base_retrieval/subgraph_retriever.py", line 156, in run_intermediate_response
    result = pool.map(run_a_request, zip(repeat(client), prompts))
  File "/home/ju/anaconda3/envs/kgqa_graphrag/lib/python3.9/multiprocessing/pool.py", line 364, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "/home/ju/anaconda3/envs/kgqa_graphrag/lib/python3.9/multiprocessing/pool.py", line 771, in get
    raise self._value
  File "/home/ju/anaconda3/envs/kgqa_graphrag/lib/python3.9/multiprocessing/pool.py", line 537, in _handle_tasks
    put(task)
  File "/home/ju/anaconda3/envs/kgqa_graphrag/lib/python3.9/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/home/ju/anaconda3/envs/kgqa_graphrag/lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
TypeError: cannot pickle '_thread.RLock' object

Is there any way to fix this? Thank you.


Solution

  • This is not about being inside or outside a class method.

    Presumably the client object contains a threading.RLock, which can't be pickled, and everything that has to be sent across the multiprocessing boundary (the mapped function and its arguments) must be picklable.

    To fix this, don't send the client to the child processes. Instead, have the run_a_request() function construct its own client object (optionally, i.e. only when no client is passed in, if you need to run it both in a single process and in a pool).
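A minimal sketch of that fix, under some assumptions: FakeClient is a hypothetical stand-in for the real API client (it holds a threading.RLock, like the object in the traceback), make_client() is a hypothetical factory (with the openai package it would presumably return OpenAI()), and the request is simplified to a placeholder complete() method instead of client.chat.completions.create(). Only plain prompt strings are sent to the workers; each call builds its own client when none is given.

```python
import pickle
import threading
from multiprocessing import Pool


class FakeClient:
    """Hypothetical stand-in for an API client such as openai.OpenAI.

    It holds a lock internally, which is what makes it unpicklable.
    """

    def __init__(self):
        self._lock = threading.RLock()

    def complete(self, prompt):
        # Placeholder for the real network call; just echoes the prompt.
        with self._lock:
            return f"answer to: {prompt}"


def make_client():
    # Hypothetical factory; real code would construct the actual client here.
    return FakeClient()


def run_a_request(prompt, client=None):
    # Build a client inside the worker process when none is supplied,
    # so nothing unpicklable has to cross the process boundary.
    if client is None:
        client = make_client()
    return client.complete(prompt)


def can_pickle(obj):
    # True if obj survives pickling, i.e. could be sent to a worker process.
    try:
        pickle.dumps(obj)
        return True
    except TypeError:
        return False


if __name__ == "__main__":
    # Root cause: the RLock inside the client cannot be pickled.
    print(can_pickle(make_client()))  # False

    prompts = ["q1", "q2", "q3"]

    # Only the prompts (plain strings) are pickled and sent to the workers.
    with Pool(3) as pool:
        results = pool.map(run_a_request, prompts)
    print(results)

    # Single-process use can still pass in an existing client.
    client = make_client()
    print([run_a_request(p, client) for p in prompts])
```

Mapping over the prompts alone also avoids a second problem in the original call: Pool.map() passes each item as a single argument, so mapping over zip(...) would call run_a_request((client, prompt)); Pool.starmap() is the method that unpacks argument tuples. If creating a client on every call is too expensive, one client per worker process could instead be created once through Pool's initializer argument.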