Tags: python, pycharm, output, backpropagation

Output changes between runs of seemingly deterministic code


I have this class that helps me calculate gradients during backpropagation:

import time
# Module-level record of nodes already reached by Value.backward().
# Nothing in this file clears it, so each node is recursed into at most once
# for the lifetime of the module.
visited = set()


class Value:
    """Scalar node in a computation graph that remembers how it was built.

    Attributes:
        data: the node's value.
        grad: gradient accumulated into this node (starts at 0.0).
        _prev: set of operand nodes this node was computed from.
        _op: symbol of the operation that produced this node ('' for leaves).
        _backward: closure that pushes this node's ``grad`` into ``_prev``.
    """

    def __init__(self, data, _children=None, _op=''):
        self.data, self.grad = data, 0.0
        # Leaf nodes have no operands, so their local backward step is a no-op.
        self._backward = lambda: None
        # NOTE: a set has no guaranteed iteration order, and backward() walks
        # it, so the traversal order is not fixed.
        self._prev = set(_children) if _children else set()
        self._op = _op

    def __add__(self, other):
        """Return a node holding ``self.data + other.data``.

        A non-Value right operand is wrapped in a ``Value`` first.
        """
        if not isinstance(other, Value):
            other = Value(other)
        out = Value(self.data + other.data, (self, other), '+')

        def _push_sum_grad():
            # Addition passes the upstream gradient through unchanged.
            self.grad += 1.0 * out.grad
            other.grad += 1.0 * out.grad

        out._backward = _push_sum_grad
        return out

    def __mul__(self, other):
        """Return a node holding ``self.data * other.data``.

        A non-Value right operand is wrapped in a ``Value`` first.
        """
        if not isinstance(other, Value):
            other = Value(other)
        out = Value(self.data * other.data, (self, other), '*')

        def _push_product_grad():
            # Each factor's gradient is scaled by the other factor's value.
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _push_product_grad
        return out

    def backward(self, output=False):
        """Run this node's local gradient step, then recurse into operands.

        If ``output`` is true, ``grad`` is first seeded with 1.0. Operands
        already in the module-level ``visited`` set are skipped. The others
        are added to it and recursed into, in ``self._prev``'s iteration
        order. This node itself is not added to ``visited``.
        """
        if output:
            self.grad = 1.0
        self._backward()
        for child in self._prev:
            if child in visited:
                continue
            visited.add(child)
            child.backward()

    def __repr__(self):
        """Show the node's value and accumulated gradient."""
        return f"Value(data={self.data}, grad={self.grad})"

# Build a small expression graph. `a` and `b` feed both c (a + b) and
# d (a * b), and c and d in turn feed both e and g, so a and b are reachable
# from f along several paths.
a = Value(2.0)
b = Value(3.0)
c = a + b
d = a * b
e = c + d
g = c * d
f = e + g

# Seed f.grad with 1.0 and propagate. Each node runs its _backward only once,
# when first reached, using whatever grad it has accumulated at that moment.
# Children are stored in a set, so the visiting order (and hence the printed
# value) may differ from one run to the next.
f.backward(True)
print(f"a grad: {a.grad}")

This code seems deterministic, and I don't use multithreading, but I get different outputs when I run it. Sometimes it outputs:

a grad: 21.0

Other times:

a grad: 4.0

I'm using PyCharm with Python 3.12.

I also tried running it in an online playground (https://www.boot.dev/playground/py), and it gives different outputs there as well.

P.S. I know that backward() doesn't implement backpropagation correctly.


Solution

  • Your

    self._prev = set(_children) ...
    

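    makes the later iteration `for child in self._prev:` non-deterministic: sets aren't ordered, so you get some arbitrary order. Because Value doesn't define __hash__, its instances hash by identity (in CPython, derived from the object's memory address). Those addresses can differ from run to run, so the iteration order, and therefore the result, can change between runs.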

    If you instead use

    self._prev = dict.fromkeys(_children) ...
    

    as an "ordered set" (dicts preserve insertion order, guaranteed since Python 3.7), the iteration order becomes deterministic and the result is always 4.0.

    And with

    self._prev = dict.fromkeys(reversed(_children)) ...
    

    you get a different, but still deterministic, order (each node's children are visited in reverse), and the result is always 21.0.