Tags: tensorflow, pytorch, autograd, gradienttape

Calculating gradients in a custom training loop: difference in performance between TF and Torch


I have attempted to translate a PyTorch implementation of a NN model that calculates forces and energies in molecular structures to TensorFlow. This required a custom training loop and a custom loss function, so I implemented the two different one-step training functions below.

  1. First attempt, using nested gradient tapes:
import tensorflow as tf

# model, force_energy_loss and dD_dr_train_batch are defined elsewhere (module level)
def calc_gradients(D_train_batch, E_train_batch, F_train_batch, opt):
    # outer tape (tape1) tracks d(Loss)/d(Weights),
    # inner tape (tape2) tracks d(output)/d(input)
    with tf.GradientTape() as tape1:
        with tf.GradientTape() as tape2:
            # D_train_batch is a plain tensor, not a variable, so watch it explicitly
            tape2.watch(D_train_batch)
            # pass D through the model to get predicted energy values
            E_pred = model(D_train_batch, training=True)

        # d(E_pred)/d(D), computed inside tape1 so tape1 can differentiate through it
        df_dD_train_batch = tape2.gradient(E_pred, D_train_batch)
        # matrix mult of -Grad_D(f) x Grad_r(D)
        F_pred = -tf.einsum('ijkl,il->ijk', dD_dr_train_batch, df_dD_train_batch)
        # calculate loss value
        loss = force_energy_loss(E_pred, F_pred, E_train_batch, F_train_batch)

    grads = tape1.gradient(loss, model.trainable_weights)
    opt.apply_gradients(zip(grads, model.trainable_weights))
  2. Second attempt, using a single persistent gradient tape (persistent=True):
def calc_gradients_persistent(D_train_batch, E_train_batch, F_train_batch, opt):
    # one persistent tape tracks both d(Loss)/d(Weights) and d(output)/d(input)
    with tf.GradientTape(persistent=True) as outer:
        # D_train_batch is a plain tensor, so watch it explicitly;
        # the model's trainable variables are watched automatically
        outer.watch(D_train_batch)

        # predicted energies; training=True only selects training-mode behaviour
        # of layers (e.g. dropout), it is not needed to track the weights
        E_pred = model(D_train_batch, training=True)

        # gradient of output (f/E_pred) w.r.t. input (D/D_train_batch); called inside
        # the context so the tape records it for the second differentiation
        df_dD_train_batch = outer.gradient(E_pred, D_train_batch)

        # matrix mult of -Grad_D(f) x Grad_r(D)
        F_pred = -tf.einsum('ijkl,il->ijk', dD_dr_train_batch, df_dD_train_batch)

        # calculate loss value
        loss = force_energy_loss(E_pred, F_pred, E_train_batch, F_train_batch)

    # gradient of loss w.r.t. trainable weights for backpropagation
    grads = outer.gradient(loss, model.trainable_weights)
    # release the resources held by the persistent tape
    del outer

    # update weights using the optimizer and the gradients (grads)
    opt.apply_gradients(zip(grads, model.trainable_weights))
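To check that both patterns really propagate the force error into the weights, here is a minimal, self-contained sketch on a hypothetical toy problem (not the DScribe model). The "network" is a single trainable scalar `w` with energy E_i = w · Σ_l D_il², and `toy_energy`, `force_loss`, `D` and `dD_dr` are made-up stand-ins. Because dE_i/dD_il = 2·w·D_il, the weight gradient of a force-only loss has a closed form to compare against. Note that in the persistent variant the first `gradient` call has to happen inside the tape's context so it gets recorded; TensorFlow logs a warning that this is less efficient than calling `gradient` outside the context.

```python
import tensorflow as tf

# Hypothetical toy problem (stand-in for the real model and DScribe data):
# the "network" is one trainable scalar w and E_i = w * sum_l D_il**2,
# so dE_i/dD_il = 2 * w * D_il.
w = tf.Variable(0.5, dtype=tf.float64)
D = tf.constant([[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]], dtype=tf.float64)  # (samples, features)
dD_dr = tf.reshape(tf.range(36, dtype=tf.float64) / 10.0, (2, 2, 3, 3))  # (samples, atoms, xyz, features)

def toy_energy(D_batch):
    return w * tf.reduce_sum(D_batch ** 2, axis=1)

def force_loss(dE_dD):
    # Force MSE against zero reference forces (energy term left out on purpose).
    F_pred = -tf.einsum('ijkl,il->ijk', dD_dr, dE_dD)
    return tf.reduce_mean(F_pred ** 2)

# Variant 1: nested tapes. The inner gradient is taken inside the outer
# tape's context, so the outer tape records it.
with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        inner.watch(D)
        E = toy_energy(D)
    loss_nested = force_loss(inner.gradient(E, D))
grad_nested = outer.gradient(loss_nested, w)

# Variant 2: one persistent tape, first gradient call made inside the context.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(D)
    E = toy_energy(D)
    loss_persistent = force_loss(tape.gradient(E, D))
grad_persistent = tape.gradient(loss_persistent, w)
del tape  # free the resources held by the persistent tape

# Analytic check: F = -2w * G with G = einsum(dD_dr, D),
# so d(mean(F**2))/dw = mean(8 * w * G**2).
G = tf.einsum('ijkl,il->ijk', dD_dr, D)
grad_analytic = tf.reduce_mean(8.0 * w * G ** 2)
```

Both variants should agree with each other and with the analytic value, which is consistent with the conclusion at the end of the post that the two translations are equivalent.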

These were attempted translations of the following PyTorch code:

# Inside the training loop, for each mini-batch:
# Forward pass: Predict energies from the descriptor input
E_train_pred_batch = model(D_train_batch)

# Get derivatives of model output with respect to input variables. The
# torch.autograd.grad-function can be used for this, as it returns the
# gradients of the outputs with respect to the inputs. It is very important
# to set create_graph=True in this case. Without it the derivatives
# of the NN parameters with respect to the loss from the force error
# will not be populated (=the force error will not affect the
# training), but the model will still run fine without errors.
df_dD_train_batch = torch.autograd.grad(
    outputs=E_train_pred_batch,
    inputs=D_train_batch,
    grad_outputs=torch.ones_like(E_train_pred_batch),
    create_graph=True,
)[0]

# Chain rule: combine the derivatives of the input variables (=descriptor)
# with respect to atom positions with df/dD to get the forces
F_train_pred_batch = -torch.einsum('ijkl,il->ijk', dD_dr_train_batch, df_dD_train_batch)

# Zero gradients, perform a backward pass, and update the weights.
# D_train_batch.grad.data.zero_()
optimizer.zero_grad()
loss = energy_force_loss(E_train_pred_batch, E_train_batch, F_train_pred_batch, F_train_batch)
loss.backward()
optimizer.step()
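The point about `create_graph=True` can be seen on the same kind of hypothetical toy problem, where a single trainable scalar `w` stands in for the network and E_i = w · Σ_l D_il². Under those assumptions, the sketch below shows that without `create_graph` the returned derivative does not require grad, while with it a force-only loss backpropagates into `w` and matches the analytic gradient mean(8·w·G²). For the TensorFlow translation, `grad_outputs=torch.ones_like(...)` corresponds to `GradientTape.gradient` on a non-scalar target (which differentiates the sum of the target), and `create_graph=True` corresponds to computing the inner gradient while an outer tape is recording.

```python
import torch

# Hypothetical toy problem: one trainable scalar w, E_i = w * sum_l D_il**2,
# and the force error as the only loss term.
w = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)
D = torch.tensor([[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]],
                 dtype=torch.float64, requires_grad=True)
dD_dr = torch.arange(36, dtype=torch.float64).reshape(2, 2, 3, 3) / 10.0

E = w * (D ** 2).sum(dim=1)

# Without create_graph the result is detached from w, so a force-only loss
# would have nothing to backpropagate into. retain_graph keeps the graph
# alive for the second call below.
dE_dD_detached = torch.autograd.grad(
    E, D, grad_outputs=torch.ones_like(E), retain_graph=True)[0]

# With create_graph=True the gradient itself stays differentiable.
dE_dD = torch.autograd.grad(
    E, D, grad_outputs=torch.ones_like(E), create_graph=True)[0]
F_pred = -torch.einsum('ijkl,il->ijk', dD_dr, dE_dD)
loss = (F_pred ** 2).mean()  # zero reference forces
loss.backward()  # populates w.grad through the second-order graph
```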

This code comes from the forces-and-energies tutorial for the DScribe library at https://singroup.github.io/dscribe/latest/tutorials/machine_learning/forces_and_energies.html

Question

With either version of the TF implementation, prediction accuracy is much worse than with the PyTorch version. Have I misunderstood the PyTorch code and translated it incorrectly? If so, where is the discrepancy?

P.S. The model directly computes the energies E; the forces F are then obtained from the gradient of E with respect to D, chained with the descriptor derivatives dD/dr. The loss function is a weighted sum of the MSEs of the forces and the energies.
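Written out, the force computation is the chain rule through the descriptor. Reading the einsum subscripts `'ijkl,il->ijk'`, i indexes samples and l descriptor features, while j and k presumably index atoms and Cartesian components. The weights w_E and w_F are placeholders, since the post does not give their values.

```latex
F_{ijk} = -\frac{\partial E_i}{\partial r_{jk}}
        = -\sum_{l} \frac{\partial D_{il}}{\partial r_{jk}} \, \frac{\partial E_i}{\partial D_{il}},
\qquad
\mathcal{L} = w_E \, \mathrm{MSE}\bigl(E, E^{\mathrm{ref}}\bigr)
            + w_F \, \mathrm{MSE}\bigl(F, F^{\mathrm{ref}}\bigr)
```

Here `dD_dr_train_batch[i, j, k, l]` plays the role of \(\partial D_{il} / \partial r_{jk}\) and `df_dD_train_batch[i, l]` that of \(\partial E_i / \partial D_{il}\).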


Solution

  • These methods are in fact equivalent; my error was elsewhere and was causing the differing results. For anyone trying to implement the TensorFlow versions: the nested gradient tapes are about 2x faster, at least in this scenario. Also make sure to wrap the functions in @tf.function so they run as graphs rather than eagerly; the speed-up is about 10x.
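As a minimal sketch of that advice, the nested-tape step can be written as a plain Python function and compiled with `tf.function`. Everything here is hypothetical: the toy energy E_i = w · Σ_l D_il², the data, the equal loss weights, and the plain SGD update with `learning_rate` (a real loop would call `opt.apply_gradients` as in the functions above). Starting from the same `w`, the eager and compiled versions should return the same loss and gradient; only the execution mode differs.

```python
import tensorflow as tf

# Hypothetical toy model and data (stand-ins for the real model and
# DScribe descriptors): E_i = w * sum_l D_il**2 with one trainable scalar w.
w = tf.Variable(0.5, dtype=tf.float64)
D = tf.constant([[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]], dtype=tf.float64)
dD_dr = tf.reshape(tf.range(36, dtype=tf.float64) / 10.0, (2, 2, 3, 3))
E_true = tf.constant([1.0, 2.0], dtype=tf.float64)
F_true = tf.zeros((2, 2, 3), dtype=tf.float64)
learning_rate = 0.01  # hypothetical value

def train_step(D_batch, dD_dr_batch, E_batch, F_batch):
    with tf.GradientTape() as outer:
        with tf.GradientTape() as inner:
            inner.watch(D_batch)
            E_pred = w * tf.reduce_sum(D_batch ** 2, axis=1)
        dE_dD = inner.gradient(E_pred, D_batch)
        F_pred = -tf.einsum('ijkl,il->ijk', dD_dr_batch, dE_dD)
        # Equal-weight sum of energy and force MSE (the weights are an assumption).
        loss = (tf.reduce_mean((E_pred - E_batch) ** 2)
                + tf.reduce_mean((F_pred - F_batch) ** 2))
    grad_w = outer.gradient(loss, w)
    # Plain SGD update in place of an optimizer.
    w.assign_sub(learning_rate * grad_w)
    return loss, grad_w

# The same Python function traced and compiled into a graph.
train_step_graph = tf.function(train_step)
```

Any speed difference only shows up when timing repeated calls on realistically sized batches; this toy is far too small to measure it.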