Tags: python-3.x, overriding, factory-method

Override the __init__ method of a base class to edit a string literal


Consider the following classes. I have no control over classes A, B, and C and can't change their implementation.

import os


class A(object):
    def __new__(cls, root=None):
        # Factory: returns an instance of B or C depending on root
        if root == "Some conditions":
            return super(A, cls).__new__(B)
        else:
            return super(A, cls).__new__(C)

    def __init__(self) -> None:
        # The hard-coded 'x' is what I need to change for my problem.
        # I believe this should be overridden in a child class.
        local_variable = self._join(self.root, 'x')
        self.x = self.foo(local_variable)  # foo (and C's _join) are not shown in this simplified version


class B(A):
    def __init__(self, root='local path to the downloaded file'):
        self.root = root
        self._join = os.path.join
        # Note: super(self.__class__, self) recurses forever if B is subclassed
        super(self.__class__, self).__init__()


class C(A):
    def __init__(self, root='address of the remote server to download the file'):
        self.root = root
        super(self.__class__, self).__init__()

What I'm trying to do is override class A's __init__ method to change local_variable (the 'x' is hard-coded and needs to be modified), and consequently change self.x.

Here is what I have tried:

class D(A):
    def __new__(cls, root=None):
        return super().__new__(cls, root)

    def __init__(self):
        super().__init__()
        # Not sure how to proceed and change local_variable and self.x

I was reading this answer and wondering whether class A's __new__ method makes overriding the __init__ method more complicated.

Note: I simplified the classes to make it easier to understand the problem.
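On the __new__ question: when __new__ returns an object that is not an instance of the class being called, Python skips __init__ entirely. So a subclass D whose __new__ delegates to A.__new__ gets back a plain B or C instance, and D.__init__ never runs. A minimal sketch with hypothetical stand-in classes (Base, Local, Remote and Child are illustrative names, not the library's):

```python
# Hypothetical stand-ins: Base plays the role of A, Local/Remote of B/C,
# and Child of the D subclass from the question.
class Base:
    def __new__(cls, remote=False):
        # Factory: always hand back a Local or Remote, never cls itself
        target = Remote if remote else Local
        return super().__new__(target)

    def __init__(self, remote=False):
        self.label = "x"


class Local(Base):
    pass


class Remote(Base):
    pass


class Child(Base):
    def __init__(self, remote=False):
        # Never runs for Child(): __new__ returns a Local, not a Child
        super().__init__(remote)
        self.label = "y"


obj = Child()
print(type(obj).__name__)      # Local
print(isinstance(obj, Child))  # False
print(hasattr(obj, "label"))   # False: no __init__ ran at all

plain = Base(remote=True)
print(type(plain).__name__, plain.label)  # Remote x
```

This is why patching A.__init__ in place, as the solution does, reaches B and C instances where a subclass cannot.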


Solution

  • You can create a wrapper function that parses the source code of a given function into an AST, walks through the AST nodes to replace the value of every string literal node (ast.Constant; the older ast.Str node type has been deprecated since Python 3.8) whose value matches 'x', and recompiles the AST back into a function object:

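    import ast
    import inspect
    from textwrap import dedent
    
    def replace_str(target, replacement):
        def wrapper(func):
            # Parse the function's own source; dedent so a method body parses on its own
            tree = ast.parse(dedent(inspect.getsource(func)))
            for node in ast.walk(tree):
                # String literals are ast.Constant nodes (ast.Str is deprecated)
                if isinstance(node, ast.Constant) and node.value == target:
                    node.value = replacement
            ast.fix_missing_locations(tree)
            scope = {}
            # Re-run the modified def in the original module's globals
            exec(compile(tree, inspect.getfile(func), "exec"), func.__globals__, scope)
            return scope[func.__name__]
        return wrapper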
    

    Then apply the wrapper to A.__init__ (simplified here for illustration):

    class A:
        def __init__(self):
            print('x')
    
    A.__init__ = replace_str('x', 'y')(A.__init__)
    A()
    

    This outputs:

    y
    

    Demo: https://replit.com/@blhsing/IvoryWhisperedOpendoc
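To check what the wrapper will change before patching anything, one can run the same transformation on source text and turn the tree back into code with ast.unparse (Python 3.9+) instead of compiling it. A minimal sketch, assuming a hypothetical SOURCE string that mirrors the simplified A.__init__ from the question and a hypothetical helper named replace_in_source:

```python
import ast

# Hypothetical source text mirroring the simplified A.__init__ from the question
SOURCE = '''
def __init__(self):
    local_variable = self._join(self.root, 'x')
    self.x = self.foo(local_variable)
'''


def replace_in_source(source, target, replacement):
    """Return source with every string literal equal to target replaced."""
    tree = ast.parse(source)
    for node in ast.walk(tree):
        # String literals are ast.Constant nodes
        if isinstance(node, ast.Constant) and node.value == target:
            node.value = replacement
    # ast.unparse turns the modified tree back into source text
    return ast.unparse(tree)


print(replace_in_source(SOURCE, 'x', 'y'))
```

Because the match is by value, every string literal equal to the target in the method gets replaced, so a preview like this is a cheap way to confirm that only the intended 'x' is hit. Also note that replace_str relies on inspect.getsource, which needs the library's .py source files to be available.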