I want to write multithreaded code that creates one inference session and runs multiple images through it at the same time. I found this code, which creates a process. Is there a parameter of the Process class that would achieve that, or should I bring a threading class into this? Basically, what should I do to achieve that, based on the code below?
    import onnxruntime as ort
    import numpy as np
    import multiprocessing as mp

    def init_session(model_path):
        EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        sess = ort.InferenceSession(model_path, providers=EP_list)
        return sess

    class PickableInferenceSession:  # This is a wrapper to make the current InferenceSession class pickable.
        def __init__(self, model_path):
            self.model_path = model_path
            self.sess = init_session(self.model_path)

        def run(self, *args):
            return self.sess.run(*args)

        def __getstate__(self):
            return {'model_path': self.model_path}

        def __setstate__(self, values):
            self.model_path = values['model_path']
            self.sess = init_session(self.model_path)

    class IOProcess(mp.Process):
        def __init__(self):
            super(IOProcess, self).__init__()
            self.sess = PickableInferenceSession('model.onnx')

        def run(self):
            print("calling run")
            print(
                self.sess.run({}, {
                    'a': np.zeros((3, 4), dtype=np.float32),
                    'b': np.zeros((4, 3), dtype=np.float32)
                }))
            # print(self.sess)

    if __name__ == '__main__':
        mp.set_start_method('spawn')  # This is important and MUST be inside the name==main block.
        io_process = IOProcess()
        io_process.start()
        io_process.join()
I'm not certain what you are asking, but you may want to try `import multiprocess.dummy as mp` — there, `mp.Process` actually uses threading. If you are asking whether there's a keyword argument that makes a process-based `Process` class perform multithreading, the answer is no: instead, use the threaded `Process` from the `dummy` module. You'll also note that I'm suggesting `multiprocess` instead of `multiprocessing`, not only because I'm the author, but because `multiprocess` uses `dill` for pickling, which generally simplifies sending objects to another process or thread. Also, judging from your code, you might want to just try a `Pool` to run in parallel:
>>> import multiprocess as mp # or multiprocess.dummy
>>> p = mp.Pool()
>>> p.map(lambda x:x, range(4))
[0, 1, 2, 3]
>>> p.close(); p.join()
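Applied to your case, the idea would be: build the session once, then map your images over a thread-based pool, since threads share the session directly (no pickling needed). This is a minimal sketch of that pattern using the stdlib `multiprocessing.dummy` (the `multiprocess.dummy` API is the same); the `infer` function below is a stand-in for your actual `sess.run(...)` call so the example runs without a model file:

```python
from multiprocessing.dummy import Pool  # thread-based Pool; multiprocess.dummy works the same way
import numpy as np

# With onnxruntime you would create ONE InferenceSession up front, e.g.:
#   sess = init_session('model.onnx')
# and have the worker call sess.run(...) on its image.
def infer(image):
    # Placeholder for the real inference call, something like:
    #   return sess.run(None, {'input': image})
    return float(image.sum())

images = [np.ones((3, 4), dtype=np.float32) * i for i in range(4)]

with Pool(4) as pool:  # 4 worker threads sharing the same interpreter (and session)
    results = pool.map(infer, images)

print(results)
```

Note that `pool.map` works with a plain named function here; if you stick with process-based `multiprocessing`, a `lambda` won't pickle, which is another reason `multiprocess`/`dill` can be convenient.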