I'm mostly an outsider trying to understand whether Rust is appropriate for my projects.
There are frameworks that do automatic differentiation in Rust. In particular, candle (and, I think, a few other projects) describes itself as doing it in a way that is similar to PyTorch.
However, I know that Rust does not allow multiple mutable references, and that seems to be exactly what PyTorch-like automatic differentiation needs:
x = torch.rand(10) # an array of 10 elements
x.requires_grad = True
y = x.sin()
z = x**2
Both y and z must keep mutable references to x, because you might want to backpropagate them, which will modify x.grad. For example:
(y.dot(z)).backward()
print(x.grad) # .backward() adds a new field (an array) to x, without modifying it otherwise
So how can similar behavior be implemented in Rust, given that it does not allow multiple mutable references?
You are correct that the Rust compiler enforces that there can be only one mutable reference to a value at a time, but there is an escape hatch: the interior mutability pattern.
This pattern allows programmers to construct data structures for which the borrowing rules are checked at run time instead of at compile time.
The standard library provides a number of containers that implement interior mutability, with different usage patterns suitable for different scenarios. Key examples are:
- RefCell<T>, which allows run time borrow checking for single-threaded usage
- RwLock<T>, which allows run time borrow checking for multi-threaded usage
- Mutex<T>, which only allows one reference at a time to its contents

There are others - see the module level documentation for cell and sync.
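Before looking at candle, here is a minimal, self-contained sketch of the pattern (plain standard-library Rust, nothing to do with candle): two handles share a single value, and either of them can mutate it, because RefCell moves the borrow checking from compile time to run time:

use std::cell::RefCell;
use std::rc::Rc;

fn main() {
    // One value, two owning handles - roughly the ownership shape that
    // y and z both needing access to x would have.
    let x = Rc::new(RefCell::new(vec![0.0f32; 10]));
    let y_view = Rc::clone(&x);
    let z_view = Rc::clone(&x);

    // Either handle can take a mutable borrow, one at a time.
    y_view.borrow_mut()[0] = 1.0;
    z_view.borrow_mut()[1] = 2.0;

    // Overlapping mutable borrows are still an error, but the check
    // happens at run time (a panic) rather than at compile time.
    println!("{:?}", x.borrow());
}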
How does this apply to candle? Let's take a peek under the hood:
pub struct Tensor_ {
    // ...
    storage: Arc<RwLock<Storage>>,
    // ...
}
The contents of the storage that backs the tensor are protected by an RwLock. In fact, there are some comments in the code immediately above this which describe the reasoning behind this particular choice - worth a read.
Not only that, but this is in turn wrapped in an Arc<T> - which means that it is in fact a heap-allocated, reference-counted value. There can be multiple 'owners' of this value, and it will only be deallocated when the last owner goes out of scope.
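To sketch what that wrapping buys you (illustrative plain Rust, not candle's actual types): cloning the Arc creates another owner of the same lock-protected storage rather than a copy of the data, and any owner can take the write lock to mutate it:

use std::sync::{Arc, RwLock};

// A stand-in for candle's Storage type, just to show the wrapping.
type Storage = Vec<f32>;

fn main() {
    let storage: Arc<RwLock<Storage>> = Arc::new(RwLock::new(vec![0.0; 3]));

    // Cloning the Arc creates another owner of the same storage,
    // not a copy of the underlying data.
    let other_owner = Arc::clone(&storage);

    // Any owner can take the write lock and mutate the contents...
    other_owner.write().unwrap()[0] = 42.0;

    // ...and every other owner observes the change.
    assert_eq!(storage.read().unwrap()[0], 42.0);
    println!("owners: {}", Arc::strong_count(&storage)); // 2
}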
How is this arrangement used in the case of backpropagation? Well, the backward() method of Tensor does not directly modify the tensor; rather, it returns a GradStore containing the computed gradients. A GradStore may in turn be consumed by an Optimizer. Optimizer is a trait with a couple of different implementations, so let's take a look at the SGD optimizer:
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
    for var in self.vars.iter() {
        if let Some(grad) = grads.get(var) {
            var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
        }
    }
    Ok(())
}
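The body of step is the classic SGD update, w := w - learning_rate * grad, computed with tensor ops and written back into the Var via set. To show the same shape without any of candle's machinery, here is a minimal plain-Rust mirror of the idea (hypothetical names and types, not candle's API), in which the 'gradient store' is just a map and each parameter is lock-protected shared storage:

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Hypothetical stand-ins for candle's Storage / Var / GradStore.
type SharedStorage = Arc<RwLock<Vec<f32>>>;
type GradMap = HashMap<usize, Vec<f32>>;

// w := w - lr * g, applied through the write lock - the same role
// that Var::set plays in the real step() above.
fn sgd_step(params: &[SharedStorage], grads: &GradMap, lr: f32) {
    for (i, param) in params.iter().enumerate() {
        if let Some(grad) = grads.get(&i) {
            let mut w = param.write().unwrap(); // exclusive access, checked at run time
            for (wj, gj) in w.iter_mut().zip(grad) {
                *wj -= lr * gj;
            }
        }
    }
}

fn main() {
    let w: SharedStorage = Arc::new(RwLock::new(vec![1.0, 2.0, 3.0]));
    // A second handle to the same storage, as another node in a
    // computation graph might hold.
    let w_alias = Arc::clone(&w);

    let mut grads = GradMap::new();
    grads.insert(0, vec![0.5, 0.5, 0.5]);

    sgd_step(&[w], &grads, 0.1);
    println!("{:?}", *w_alias.read().unwrap()); // [0.95, 1.95, 2.95]
}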
OK, so back to candle's step: the gradients are being applied to some Var instances - what are these (defined here)?
pub struct Var(Tensor);
Ok, a wrapper around a Tensor. And how does the set method do its job? This line is key:
let (mut dst, layout) = self.storage_mut_and_layout();
That gives us a mutable variable that seems to represent the destination for the set operation. What does this storage_mut_and_layout() method do?
let storage = self.storage.write().unwrap();
Ahah! It calls the write() method on the RwLock we saw above, inside which the storage lives. The documentation for this method says:
Locks this RwLock with exclusive write access, blocking the current thread until it can be acquired.
This function will not return while other writers or other readers currently have access to the lock.
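As a tiny, self-contained demonstration of that guarantee (plain standard-library Rust, not candle code): every thread below holds only a shared Arc handle, yet write() hands each of them exclusive mutable access to the contents in turn:

use std::sync::{Arc, RwLock};
use std::thread;

fn main() {
    let storage = Arc::new(RwLock::new(vec![0.0f32; 4]));

    let handles: Vec<_> = (0..4)
        .map(|i| {
            let storage = Arc::clone(&storage);
            thread::spawn(move || {
                // write() blocks until no other readers or writers hold
                // the lock, then hands back exclusive mutable access.
                let mut guard = storage.write().unwrap();
                guard[i] = i as f32;
            })
        })
        .collect();

    for h in handles {
        h.join().unwrap();
    }
    println!("{:?}", storage.read().unwrap()); // [0.0, 1.0, 2.0, 3.0]
}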
So in summary:
- the backward() method itself does not seem to modify the input Tensor, but returns a data structure (a GradStore) containing the gradients
- those gradients can then be applied to the Tensor using an Optimizer
- the Optimizer uses the set method to alter the Tensor, which under the hood gets mutable access to the Tensor's data storage using the write() method on the RwLock that is protecting it.