rust, backpropagation, automatic-differentiation

How can PyTorch-like automatic differentiation work in Rust, given that it does not allow multiple mutable references?


I'm mostly an outsider trying to understand if Rust is appropriate for my projects.

There are frameworks that do automatic differentiation in Rust. Specifically, candle (and, I think, some other projects) somehow does it in a way that is similar to PyTorch, according to its description.

However, I know that Rust does not allow multiple mutable references. And it seems like that is what's needed for PyTorch-like automatic differentiation:

import torch

x = torch.rand(10)  # an array of 10 elements
x.requires_grad = True

y = x.sin()
z = x**2

Both y and z must keep mutable references to x, because you might want to backpropagate them, which will modify x.grad. For example:

(y.dot(z)).backward()
print(x.grad)  # .backward() adds a new field (an array) to x, without modifying it otherwise

So how can similar behavior be implemented in Rust, given that it does not allow multiple mutable references?


Solution

  • You are correct that the Rust compiler enforces that there can only be one mutable reference to a value at a time, but there is an escape hatch: the interior mutability pattern.

    This pattern allows programmers to construct data structures for which the rules are checked at run time instead of compile time.

    The standard library provides a number of containers that implement interior mutability, with different usage patterns suitable for different scenarios. Key examples are:

    • Cell<T> and RefCell<T>, for single-threaded use, and
    • Mutex<T> and RwLock<T>, which are thread-safe and are typically shared between threads via Arc<T>.

    There are others - see the module-level documentation for cell and sync.
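
    As a minimal sketch of the idea (plain standard-library Rust, not candle code), two handles can share one value and each can mutate it, with the borrow rules checked at run time by RefCell:

    use std::cell::RefCell;
    use std::rc::Rc;

    fn main() {
        // One heap-allocated value, two owners.
        let grad = Rc::new(RefCell::new(vec![0.0f32; 10]));
        let y_handle = Rc::clone(&grad);
        let z_handle = Rc::clone(&grad);

        // Each handle can take a mutable borrow in turn; the exclusivity of
        // the borrow is checked at run time rather than by the compiler.
        y_handle.borrow_mut()[0] += 1.0;
        z_handle.borrow_mut()[0] += 2.0;

        assert_eq!(grad.borrow()[0], 3.0);
    }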

    How does this apply to candle? Let's take a peek under the hood:

    pub struct Tensor_ {
        ...
        storage: Arc<RwLock<Storage>>,
        ...
    }

    The contents of the storage that backs the tensor are protected by an RwLock. In fact, there are some comments in the code immediately above this which describe the reason for the choice of this particular solution - worth a read.

    Not only that, but this is in turn wrapped in an Arc<T> - which means that it is in fact a heap-allocated, reference-counted value. There can be multiple 'owners' of this value, and it will only be deallocated when the last owner goes out of scope.
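
    As an illustration of what that combination gives you (a simplified stand-in, not candle's actual Storage type), several clones of an Arc<RwLock<_>> all point at the same backing buffer, and any of them can read or write it through the lock:

    use std::sync::{Arc, RwLock};

    fn main() {
        // Shared, reference-counted storage behind a read-write lock,
        // mirroring the shape of candle's `storage: Arc<RwLock<Storage>>`.
        let storage = Arc::new(RwLock::new(vec![1.0f32, 2.0, 3.0]));

        // Two "tensors" holding the same storage.
        let a = Arc::clone(&storage);
        let b = Arc::clone(&storage);

        // Many readers may hold the lock at the same time...
        assert_eq!(*a.read().unwrap(), *b.read().unwrap());

        // ...but a writer gets exclusive access for the duration of the write.
        b.write().unwrap()[0] = 10.0;
        assert_eq!(a.read().unwrap()[0], 10.0);

        // The buffer is deallocated only when the last clone is dropped.
    }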

    How is this used in the case of backpropagation? Well, the backward() method of Tensor does not directly modify the tensor; rather, it returns a GradStore containing the computed gradients. A GradStore may in turn be consumed by an Optimizer. Optimizer is a trait with a couple of different implementations, so let's take a look at the SGD optimizer:

        fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
            for var in self.vars.iter() {
                if let Some(grad) = grads.get(var) {
                    var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
                }
            }
            Ok(())
        }
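
    To see how those pieces fit together from the caller's side, here is a minimal training-step sketch. It assumes the candle-core / candle-nn crates roughly as published at the time of writing (Var::new, Tensor::backward, and SGD via the Optimizer trait); exact crate and module paths may differ between versions and from the candle::backprop path seen in the snippet above.

    // A minimal sketch, assuming candle-core / candle-nn; not the library's
    // own example code.
    use candle_core::{Device, Result, Tensor, Var};
    use candle_nn::{Optimizer, SGD};

    fn main() -> Result<()> {
        let dev = Device::Cpu;

        // A trainable parameter: Var wraps a Tensor whose shared storage
        // can later be overwritten in place via Var::set.
        let w = Var::new(3.0f32, &dev)?;
        let x = Tensor::new(2.0f32, &dev)?;

        let mut sgd = SGD::new(vec![w.clone()], 0.1)?;

        // Forward pass: builds the computation graph, mutates nothing.
        let loss = w.as_tensor().mul(&x)?.sqr()?;

        // backward() hands the gradients back in a separate GradStore
        // rather than writing them into the tensors themselves.
        let grads = loss.backward()?;

        // step() looks up each Var's gradient and calls Var::set, which
        // takes the write lock on that Var's shared storage.
        sgd.step(&grads)?;

        println!("updated w = {}", w.as_tensor());
        Ok(())
    }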
    

    OK, so the gradients are here being applied to some Var instances - what are these (defined here)?

    pub struct Var(Tensor);
    

    OK, a wrapper around a Tensor. And how does the set method do its job? This line is key:

    let (mut dst, layout) = self.storage_mut_and_layout();
    

    That gives us a mutable variable that seems to represent the destination for the set operation. What does this storage_mut_and_layout() method do?

    let storage = self.storage.write().unwrap();
    

    Aha! It calls the write() method on the RwLock we saw above, inside which the storage lives. The documentation for this method says:

    Locks this RwLock with exclusive write access, blocking the current thread until it can be acquired.

    This function will not return while other writers or other readers currently have access to the lock.
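
    Putting the whole pattern together in miniature (a toy stand-in, not candle's real types): a Var-style wrapper only ever needs &self to replace its contents, because exclusive access comes from the write lock at run time rather than from the borrow checker:

    use std::sync::{Arc, RwLock};

    /// Toy stand-in for candle's storage-sharing scheme.
    #[derive(Clone)]
    struct ToyTensor {
        storage: Arc<RwLock<Vec<f32>>>,
    }

    impl ToyTensor {
        fn new(data: Vec<f32>) -> Self {
            Self { storage: Arc::new(RwLock::new(data)) }
        }

        /// Read access: many readers may hold the lock at once.
        fn read(&self) -> Vec<f32> {
            self.storage.read().unwrap().clone()
        }

        /// In the spirit of Var::set: note that it takes &self, not &mut self.
        /// Exclusive access comes from the write lock, not the borrow checker.
        fn set(&self, new_data: &[f32]) {
            let mut dst = self.storage.write().unwrap();
            dst.copy_from_slice(new_data);
        }
    }

    fn main() {
        let w = ToyTensor::new(vec![1.0, 2.0, 3.0]);

        // Other parts of the graph can hold clones that share the same storage.
        let view = w.clone();

        // Apply a "gradient step" through the write lock.
        w.set(&[0.9, 1.8, 2.7]);

        // The shared view observes the update.
        assert_eq!(view.read(), vec![0.9, 1.8, 2.7]);
    }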

    So in summary: