How do I solve a circular reference problem when defining a tree in Python 3.12?


I'm trying to follow this tutorial, using Python 3.12 as required.

I'm trying to run this code:

from dataclasses import dataclass
from typing import Callable

type RoseTree[T] = Branch[T] | Leaf[T]

@dataclass
class Branch[A]:
    branches: list[RoseTree[A]]

    def map[B](self, f: Callable[[A], B]) -> Branch[B]:
        return Branch([b.map(f) for b in self.branches])


@dataclass
class Leaf[A]:
    value: A

    def map[B](self, f: Callable[[A], B]) -> Leaf[B]:
        return Leaf(f(self.value))

And I get the following error:

Traceback (most recent call last):
  File "/mnt/c/Users/Pierre-Olivier/Documents/python/3.12/a.py", line 8, in <module>
    class Branch[A]:
  File "/mnt/c/Users/Pierre-Olivier/Documents/python/3.12/a.py", line 8, in <generic parameters of Branch>
    class Branch[A]:
  File "/mnt/c/Users/Pierre-Olivier/Documents/python/3.12/a.py", line 11, in Branch
    def map[B](self, f: Callable[[A], B]) -> Branch[B]:
  File "/mnt/c/Users/Pierre-Olivier/Documents/python/3.12/a.py", line 11, in <generic parameters of map>
    def map[B](self, f: Callable[[A], B]) -> Branch[B]:
                                             ^^^^^^
NameError: name 'Branch' is not defined

If I move the definition of RoseTree below Branch and Leaf, I get this error instead:

Traceback (most recent call last):
  File "/mnt/c/Users/Pierre-Olivier/Documents/python/3.12/a.py", line 6, in <module>
    class Branch[A]:
  File "/mnt/c/Users/Pierre-Olivier/Documents/python/3.12/a.py", line 6, in <generic parameters of Branch>
    class Branch[A]:
  File "/mnt/c/Users/Pierre-Olivier/Documents/python/3.12/a.py", line 7, in Branch
    branches: list[RoseTree[A]]
                   ^^^^^^^^
NameError: name 'RoseTree' is not defined

I suppose it is a circular reference problem: either RoseTree can't be built because Branch isn't defined yet, or Branch can't be built because RoseTree isn't defined yet.


Solution

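  • Use strings (forward references) for annotations that name a class which is not bound yet when the annotation is evaluated. In Python 3.12 the value of a type alias is evaluated lazily, so RoseTree may mention Branch and Leaf before they exist, but annotations are evaluated eagerly while the class body runs. Inside the body of Branch the name Branch is not bound yet, so -> Branch[B] raises NameError; quoting it as 'Branch[B]' defers the lookup. For the same reason the type statement must stay first: the field annotation list[RoseTree[A]] needs the name RoseTree to be bound already, although the alias value itself is only resolved later.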

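    from dataclasses import dataclass
    from typing import Callable

    # Lazily evaluated: Branch and Leaf may be referenced before they exist.
    type RoseTree[T] = Branch[T] | Leaf[T]

    @dataclass
    class Branch[A]:
        # RoseTree is already bound here, so no quotes are needed.
        branches: list[RoseTree[A]]

        # Branch is not bound yet inside its own body: quote the annotation.
        def map[B](self, f: Callable[[A], B]) -> 'Branch[B]':
            # Recursively map over every subtree (Branch or Leaf).
            return Branch([b.map(f) for b in self.branches])


    @dataclass
    class Leaf[A]:
        value: A

        # Same situation for Leaf inside the body of Leaf.
        def map[B](self, f: Callable[[A], B]) -> 'Leaf[B]':
            return Leaf(f(self.value))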