
Parallelizing Monte Carlo Tree Search


I have a Monte Carlo Tree Search implementation that I need to optimize, so I thought about parallelizing the rollout phase. How can I do that? Is there a code example? Are there any Python modules you would recommend?

I apologize if this isn't the right place to post this.


Solution

  • You didn't provide any example code, so it's hard to solve your problem completely. Suppose your search loop runs the rollouts sequentially, like this:

    class MCTS:
        ...
        def _run_search(self):
            ...
            for node in nodes:
                node.reward = self._rollout(node)  # rollout, one node at a time
            ...
    

    With multiprocessing, you can run those rollouts in a pool of worker processes instead:

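    import os
    from multiprocessing import Pool

    class MCTS:
        ...
        def _run_search(self):
            ...
            # leave a couple of cores free, but always use at least one worker
            n_workers = max(1, (os.cpu_count() or 1) - 2)
            with Pool(n_workers) as p:
                # each worker receives a pickled copy of its node, so collect
                # the rewards here and write them back in the parent process
                rewards = p.map(self._rollout, nodes)
            for node, reward in zip(nodes, rewards):
                node.reward = reward
            ...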
    

    If you can share a minimal example of your code, it will be easier to give a more specific answer.
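Because no code was given, here is a self-contained sketch of the same idea. It is built on hypothetical pieces that are not part of your code or any library: the toy game (an integer total that grows by 1 or 2 until it reaches 10), the `Node` and `MCTS` classes, and the `rollout_batch` method. It reflects three design choices:

- The pool is created once and reused across search iterations, because starting processes costs far more than a single rollout.
- The rollout is a module-level function that receives only `(state, seed)`. Passing a bound method such as `self._rollout` would make multiprocessing pickle the whole MCTS object, tree included, and a `Pool` stored on `self` cannot be pickled at all.
- Rewards are written back in the parent, because workers only ever see copies.

```python
import os
import random
from multiprocessing import Pool

# Hypothetical toy game: the state is an integer total; each move adds 1 or 2,
# and the game ends once the total reaches TARGET or more.
TARGET = 10


def legal_moves(state):
    return [1, 2]


def is_terminal(state):
    return state >= TARGET


def rollout(args):
    """Random playout from `state`; reward 1.0 if the total lands exactly on TARGET."""
    state, seed = args
    rng = random.Random(seed)  # per-rollout RNG, independent of other workers
    while not is_terminal(state):
        state += rng.choice(legal_moves(state))
    return 1.0 if state == TARGET else 0.0


class Node:  # hypothetical minimal tree node
    def __init__(self, state):
        self.state = state
        self.reward = 0.0


class MCTS:  # hypothetical: only the rollout part of the search is shown
    def __init__(self, processes=None):
        # Create the pool once; reuse it for every batch of rollouts.
        self.pool = Pool(processes or max(1, (os.cpu_count() or 1) - 2))

    def close(self):
        self.pool.close()
        self.pool.join()

    def rollout_batch(self, nodes, seed=0):
        # Only (state, seed) pairs travel to the workers; the Node objects stay
        # in this process, so the rewards are assigned here.
        jobs = [(node.state, seed + i) for i, node in enumerate(nodes)]
        rewards = self.pool.map(rollout, jobs)
        for node, reward in zip(nodes, rewards):
            node.reward = reward
        return rewards


if __name__ == "__main__":
    mcts = MCTS()
    leaves = [Node(s) for s in range(10)]
    rewards = mcts.rollout_batch(leaves)
    mcts.close()
    print(rewards)
```

In a real search, selection, expansion and backpropagation stay in the parent process. Only the batch of leaf states selected in each iteration is handed to `rollout_batch`.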

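    By the way, whenever the iterations of a for loop are independent of each other (and the function and its arguments can be pickled), you can replace the loop with multiprocessing. Here are four ways to apply a function to every item of a list: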

    Method 1: for loop

    # inp
    data = [1,2,3,4,5]
    def f(x):
        return x**2
    
    # processing (for loop)
    result = []
    for i in data:
        result.append(f(i))
    
    # out
    print(result)  # [1, 4, 9, 16, 25]
    

    Method 2: multiprocessing (what you want)

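    import os
    from multiprocessing import Pool

    # inp
    data = [1,2,3,4,5]
    def f(x):  # defined at module level so worker processes can unpickle it
        return x**2

    # processing (multiprocessing); the __main__ guard is required where
    # workers are started with "spawn" (the default on Windows and macOS)
    if __name__ == "__main__":
        with Pool(max(1, (os.cpu_count() or 1) - 2)) as p:
            result = p.map(f, data)

        # out
        print(result)  # [1, 4, 9, 16, 25]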
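On the question of which module to use: the standard library's `concurrent.futures` module offers the same pattern through `ProcessPoolExecutor`. The block below is a minimal sketch of Method 2 rewritten with it. It repeats `f` and `data` so that it runs on its own.

```python
from concurrent.futures import ProcessPoolExecutor


def f(x):
    return x**2


if __name__ == "__main__":
    data = [1, 2, 3, 4, 5]
    # With no max_workers argument, the executor picks a default based on the CPU count.
    with ProcessPoolExecutor() as executor:
        # executor.map returns an iterator; results come back in input order.
        result = list(executor.map(f, data))
    print(result)  # [1, 4, 9, 16, 25]
```

Choosing between `multiprocessing.Pool` and `ProcessPoolExecutor` is mostly a matter of taste for this use. The executor also offers `submit()`, which returns futures if you want to collect rollouts as they finish.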
    

    Method 3: list comprehension

    # inp
    data = [1,2,3,4,5]
    def f(x):
        return x**2
    
    # processing (list comprehension)
    result = [f(i) for i in data]
    
    # out
    print(result)  # [1, 4, 9, 16, 25]
    

    Method 4: map + lambda expression

    # inp
    data = [1,2,3,4,5]
    def f(x):
        return x**2
    
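    # processing (lambda + map); in Python 3, map() returns a lazy iterator,
    # so wrap it in list() to get the values (map(f, data) works too)
    result = list(map(lambda x: f(x), data))
    
    # out
    print(result)  # [1, 4, 9, 16, 25]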
    
    1. for loop
    2. multiprocessing
    3. list comprehension
    4. map + lambda expression

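    As you can see, a for loop, multiprocessing, a list comprehension, and map + lambda are four ways to apply the same function to every item in a batch. Only multiprocessing runs the calls in parallel. The other three run them one after another in a single process. For a cheap function like x**2, the cost of starting workers and pickling data usually outweighs any speedup from multiprocessing.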
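To see when parallelism pays off, you can time a sequential loop against `Pool.map` on a deliberately expensive function. The sketch below does this. `slow_square` is a hypothetical stand-in for a CPU-bound rollout, and the loop count is arbitrary. The timings depend on your machine, so no numbers are claimed here. The final check only confirms that both strategies return the same results.

```python
import time
from multiprocessing import Pool


def slow_square(x):
    # Hypothetical stand-in for an expensive rollout: burn CPU, then return x**2.
    total = 0
    for i in range(1_000_000):
        total += i % 7
    return x**2


def timed(label, fn):
    # Run fn(), print how long it took, and return its result.
    start = time.perf_counter()
    out = fn()
    print(f"{label}: {time.perf_counter() - start:.2f}s")
    return out


if __name__ == "__main__":
    data = list(range(8))
    seq = timed("for loop", lambda: [slow_square(x) for x in data])
    with Pool() as p:
        par = timed("Pool.map", lambda: p.map(slow_square, data))
    # Same results; only the execution strategy differs.
    assert seq == par
```

If you shrink the loop inside `slow_square` towards zero, the parallel version typically stops winning. This is the same trade-off to check for your MCTS rollouts before committing to multiprocessing.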