Tags: python, machine-learning, pytorch, gradient-descent, stochastic-gradient

Can you affine warp a tensor while preserving gradient flow?


I'm trying to recreate the cv2.warpAffine() function, taking a tensor as input and output rather than a NumPy array. However, gradients calculated from the output tensor produce a non-None gradient tensor of all zeros. I've spent all night trying to see where I'm going wrong. If someone can spot my mistake, I'd be very happy to hear what it is, and I'm happy to provide any further information.

import torch

def tensor_warp_affine(tensor, M, output_shape, border_value=0.0):
    height, width = output_shape
    channels = tensor.shape[2]
    device = tensor.device
    dtype = tensor.dtype

    # Convert M to tensor and compute inverse
    M_tensor = torch.tensor(M, dtype=torch.float32, device=device)
    M_padded = torch.cat([M_tensor, torch.tensor([[0, 0, 1]], dtype=torch.float32, device=device)], dim=0)
    M_inv = torch.inverse(M_padded)[:2]
    # Create coordinate grids
    y_coords, x_coords = torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device), indexing='ij')
    coords = torch.stack([x_coords.flatten(), y_coords.flatten(), torch.ones_like(x_coords.flatten())], dim=0).float()
    # Transform coordinates
    src_coords = M_inv @ coords
    src_x, src_y = src_coords[0], src_coords[1]
    # Compute interpolation weights
    x0 = src_x.long()
    y0 = src_y.long()
    x1 = x0 + 1
    y1 = y0 + 1
    x0 = torch.clamp(x0, 0, tensor.shape[1] - 1)
    x1 = torch.clamp(x1, 0, tensor.shape[1] - 1)
    y0 = torch.clamp(y0, 0, tensor.shape[0] - 1)
    y1 = torch.clamp(y1, 0, tensor.shape[0] - 1)
    dx = src_x - x0.float()
    dy = src_y - y0.float()
    # Perform bilinear interpolation
    wa = (1 - dx) * (1 - dy)
    wb = dx * (1 - dy)
    wc = (1 - dx) * dy
    wd = dx * dy
    img_flat = tensor.reshape(-1, channels)
   
    pixel = (wa[:, None] * img_flat[y0 * tensor.shape[1] + x0] +
             wb[:, None] * img_flat[y0 * tensor.shape[1] + x1] +
             wc[:, None] * img_flat[y1 * tensor.shape[1] + x0] +
             wd[:, None] * img_flat[y1 * tensor.shape[1] + x1])
   
    result = pixel.reshape(height, width, channels)
    return result

I'm doing this so I can pass the resulting tensor into an encoder model that generates latent vectors. I'm trying to optimise perturbations added to an input image using stochastic gradient descent. My suspicion is that some of the transformed points are out of bounds and are being clamped to zero. Is that non-differentiable?
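
For reference, the surrounding optimisation loop looks roughly like this (a minimal sketch: the image, matrix, encoder, and hyperparameters are placeholder stand-ins rather than my real setup):

import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the real data, transform, and encoder:
image = torch.rand(64, 64, 3)                        # H x W x C input image
M = [[1.0, 0.0, 10.0], [0.0, 1.0, 20.0]]             # cv2-style 2x3 affine matrix
output_shape = (64, 64)
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(64 * 64 * 3, 16))
target_latent = torch.zeros(1, 16)

delta = torch.zeros_like(image, requires_grad=True)  # perturbation to optimise
optimizer = torch.optim.SGD([delta], lr=0.01)

for step in range(100):
    optimizer.zero_grad()
    warped = tensor_warp_affine(image + delta, M, output_shape)
    latent = encoder(warped.permute(2, 0, 1).unsqueeze(0))  # encoder expects (N, C, H, W)
    loss = F.mse_loss(latent, target_latent)
    loss.backward()  # delta.grad comes back as all zeros
    optimizer.step()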

I've tried printing out out-of-bounds values, but nothing ever flags up. I've also tried PyTorch's affine_grid, but the results seem to be incorrect in my implementation. Input tensor:

tensor([[[0.0000, 0.0078, 0.0078],
     [0.0039, 0.0118, 0.0118],
     [0.0000, 0.0078, 0.0078],
     ...,
     [0.0000, 0.0157, 0.0196],
     [0.0000, 0.0157, 0.0196],
     [0.0000, 0.0157, 0.0196]],

    [[0.0000, 0.0078, 0.0078],
     [0.0039, 0.0118, 0.0118],
     [0.0000, 0.0078, 0.0078],
     ...,
     [0.0000, 0.0157, 0.0196],
     [0.0000, 0.0157, 0.0196],
     [0.0000, 0.0157, 0.0196]],

    [[0.0000, 0.0078, 0.0078],
     [0.0000, 0.0078, 0.0078],
     [0.0000, 0.0078, 0.0078],
     ...,
     [0.0000, 0.0157, 0.0196],
     [0.0000, 0.0157, 0.0196],
     [0.0000, 0.0157, 0.0196]],

    ...,

    [[0.0275, 0.1059, 0.4392],
     [0.0196, 0.0980, 0.4314],
     [0.0157, 0.0941, 0.4275],
     ...,
     [0.3373, 0.5216, 0.8392],
     [0.3333, 0.5176, 0.8353],
     [0.3059, 0.4902, 0.8078]],

    [[0.0275, 0.1059, 0.4392],
     [0.0196, 0.0980, 0.4314],
     [0.0196, 0.0980, 0.4314],
     ...,
     [0.3373, 0.5216, 0.8392],
     [0.3216, 0.5059, 0.8235],
     [0.3059, 0.4902, 0.8078]],

    [[0.0275, 0.1059, 0.4392],
     [0.0196, 0.0980, 0.4314],
     [0.0196, 0.0980, 0.4314],
     ...,
     [0.3176, 0.5020, 0.8196],
     [0.3216, 0.5059, 0.8235],
     [0.3176, 0.5020, 0.8196]]], grad_fn=<SqueezeBackward1>)

The desired output tensor should look something like this:

tensor([[[8.5652e-03, 4.3859e-02, 9.7611e-02],
     [7.6292e-03, 4.6845e-02, 8.6061e-02],
     [0.0000e+00, 3.5453e-02, 7.4669e-02],
     ...,
     [3.9216e-03, 1.1765e-02, 5.0980e-02],
     [2.5561e-03, 1.7278e-02, 3.6886e-02],
     [0.0000e+00, 7.8431e-03, 4.7059e-02]],

    [[3.9216e-03, 3.9216e-02, 9.4118e-02],
     [7.7203e-04, 3.9988e-02, 7.9203e-02],
     [0.0000e+00, 3.7218e-02, 6.9977e-02],
     ...,
     [1.6439e-03, 8.9082e-03, 4.8124e-02],
     [2.9788e-03, 1.0822e-02, 5.0038e-02],
     [3.9216e-03, 1.1765e-02, 5.0980e-02]],

    [[0.0000e+00, 3.9216e-02, 6.6035e-02],
     [0.0000e+00, 3.1373e-02, 6.6667e-02],
     [0.0000e+00, 3.5294e-02, 3.5833e-02],
     ...,
     [0.0000e+00, 3.9216e-03, 4.3137e-02],
     [3.9216e-03, 1.0668e-02, 4.9884e-02],
     [4.4216e-04, 5.1212e-03, 4.4337e-02]],

    ...,

    [[1.4902e-01, 2.5882e-01, 3.7647e-01],
     [1.1612e-01, 2.2642e-01, 3.4406e-01],
     [9.3184e-02, 1.7629e-01, 2.6298e-01],
     ...,
     [8.1880e-02, 1.7951e-01, 3.2182e-01],
     [1.0127e-01, 1.7469e-01, 3.4608e-01],
     [2.0947e-02, 7.9770e-02, 2.1282e-01]],

    [[1.4533e-01, 2.6514e-01, 3.8574e-01],
     [1.2677e-01, 2.4847e-01, 3.7004e-01],
     [9.8382e-02, 1.9250e-01, 2.9153e-01],
     ...,
     [8.0654e-02, 1.5194e-01, 3.3930e-01],
     [1.8802e-01, 2.5493e-01, 4.4500e-01],
     [1.8050e-02, 6.2331e-02, 2.3506e-01]],

    [[1.4967e-01, 2.6551e-01, 3.8581e-01],
     [1.1913e-01, 2.3896e-01, 3.6051e-01],
     [9.7532e-02, 1.9963e-01, 2.9095e-01],
     ...,
     [5.5842e-02, 1.2251e-01, 3.1291e-01],
     [1.6402e-01, 2.2266e-01, 4.1669e-01],
     [8.3060e-02, 1.3796e-01, 3.3012e-01]]], grad_fn=<ViewBackward0>)

Gradient calculation:

tensor([[[[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      ...,
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.]],

     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      ...,
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.]],

     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      ...,
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.]],

     ...,

     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      ...,
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.]],

     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      ...,
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.]],

     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      ...,
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.]]]])

Solution

  • Consider where your gradients are coming from. You want to backprop from your post-transformation result into your input image.

    This requires a direct computational link between your result and your input.

    Looking at the code, the input only participates in the computation at the final step:

        img_flat = tensor.reshape(-1, channels)
       
        pixel = (wa[:, None] * img_flat[y0 * tensor.shape[1] + x0] +
                 wb[:, None] * img_flat[y0 * tensor.shape[1] + x1] +
                 wc[:, None] * img_flat[y1 * tensor.shape[1] + x0] +
                 wd[:, None] * img_flat[y1 * tensor.shape[1] + x1])
       
        result = pixel.reshape(height, width, channels)
        return result
    

    That snippet is the only place where the input image is used. The input image does not participate in the computation of wa, ..., wd or of x0, x1, y0, y1; it only participates in the indexing operation.

    This means that the only gradient signal sent back to the input image will be at the index values selected in the above snippet. The specific pixels selected by img_flat[y0 * tensor.shape[1] + x0] and the other similar operations are the only pixel values that will have gradient. There is no way around this.

    Additionally, the weight vectors wa, ..., wd contain zeros, and these zero out further gradients. Take wa[:, None] * img_flat[y0 * tensor.shape[1] + x0], for example: that operation can only pass gradient signal through the pixels indexed by y0 * tensor.shape[1] + x0 where wa != 0.
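
    As a toy illustration (not taken from the code above), integer indexing only routes gradient to the selected entries:

    import torch

    p = torch.ones(5, requires_grad=True)
    out = p[torch.tensor([0, 2])].sum()  # only entries 0 and 2 participate
    out.backward()
    print(p.grad)  # tensor([1., 0., 1., 0., 0.]) -- unselected entries get zero gradient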

    We can see this empirically:

    import torch
    
    torch.manual_seed(42)
    tensor = torch.rand(64, 64, 3, requires_grad=True)
    M = [[1, 0, 10], [0, 1, 20]]
    output_shape = (64, 64)
    
    # function body inlined so we can work with the intermediates
    height, width = output_shape
    channels = tensor.shape[2]
    device = tensor.device
    dtype = tensor.dtype
    
    
    # Convert M to tensor and compute inverse
    M_tensor = torch.tensor(M, dtype=torch.float32, device=device)
    M_padded = torch.cat([M_tensor, torch.tensor([[0, 0, 1]], dtype=torch.float32, device=device)], dim=0)
    M_inv = torch.inverse(M_padded)[:2]
    # Create coordinate grids
    y_coords, x_coords = torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device), indexing='ij')
    coords = torch.stack([x_coords.flatten(), y_coords.flatten(), torch.ones_like(x_coords.flatten())], dim=0).float()
    # Transform coordinates
    src_coords = M_inv @ coords
    src_x, src_y = src_coords[0], src_coords[1]
    # Compute interpolation weights
    x0 = src_x.long()
    y0 = src_y.long()
    x1 = x0 + 1
    y1 = y0 + 1
    x0 = torch.clamp(x0, 0, tensor.shape[1] - 1)
    x1 = torch.clamp(x1, 0, tensor.shape[1] - 1)
    y0 = torch.clamp(y0, 0, tensor.shape[0] - 1)
    y1 = torch.clamp(y1, 0, tensor.shape[0] - 1)
    dx = src_x - x0.float()
    dy = src_y - y0.float()
    # Perform bilinear interpolation
    wa = (1 - dx) * (1 - dy)
    wb = dx * (1 - dy)
    wc = (1 - dx) * dy
    wd = dx * dy
    img_flat = tensor.reshape(-1, channels)
    
    pixel = (wa[:, None] * img_flat[y0 * tensor.shape[1] + x0] +
             wb[:, None] * img_flat[y0 * tensor.shape[1] + x1] +
             wc[:, None] * img_flat[y1 * tensor.shape[1] + x0] +
             wd[:, None] * img_flat[y1 * tensor.shape[1] + x1])
    
    result = pixel.reshape(height, width, channels)
    
    # example loss calculation to get gradient
    loss = result.mean()
    loss.backward()
    
    # get gradient
    grad = tensor.grad
    grad_nonzero = (grad != 0).sum()
    fraction_nonzero = grad_nonzero / grad.numel()
    
    print(f"Gradient has {grad_nonzero} nonzero values, {fraction_nonzero:.4f} percent of total values")
    
    # grab the index values used, ignore values where w_{i} == 0
    valid_indices = torch.cat([
                        (y0 * tensor.shape[1] + x0)[wa != 0],
                        (y0 * tensor.shape[1] + x1)[wb != 0],
                        (y1 * tensor.shape[1] + x0)[wc != 0],
                        (y1 * tensor.shape[1] + x1)[wd != 0],
                    ]).unique()
    
    total_valid_indices = valid_indices.shape[0]
    total_valid_grad_elements = total_valid_indices * 3 # pixel structure has 3 grad values per index
    print(f"Valid indices: {total_valid_indices}")
    print(f"Valid grad elements: {total_valid_grad_elements}")
    

    When I run this code it prints the following:

    Gradient has 7128 nonzero values, 58.01% of total values
    Valid indices: 2376
    Valid grad elements: 7128
    

    We can see the gradient has 7128 nonzero values, which matches exactly the number computed from the unique indices used.

    With the M and output_shape parameters used at the start, there are only 2376 specific pixel indices from the input image that are propagated to the output image. This results in 7128 (2376*3) non-zero gradient values in the input tensor.

    This can be improved slightly by using some detach hackery to recover the gradients that the w_{i} weights would otherwise zero out:

    w_list = [wa, wb, wc, wd]
    pixel_list = [
        img_flat[y0 * tensor.shape[1] + x0],
        img_flat[y0 * tensor.shape[1] + x1],
        img_flat[y1 * tensor.shape[1] + x0],
        img_flat[y1 * tensor.shape[1] + x1]
    ]

    pixel = []

    for i in range(len(w_list)):
        # Forward value is still w * pixel, but the backward pass only flows
        # through the bare pixel term, so the weight can no longer zero out the gradient.
        pixel.append((w_list[i][:, None] * pixel_list[i]).detach() + pixel_list[i] - pixel_list[i].detach())

    pixel = torch.stack(pixel).sum(0)
    

    This keeps all the gradients associated with each pixel_list[i], regardless of the weight values.

    For the parameters I used, this increases the nonzero gradient values from 7128 to 7425.
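
    To see why the trick works: (w * p).detach() + p - p.detach() has the same forward value as w * p, while the backward pass only flows through the bare p term, i.e. an identity (a straight-through style estimator). A tiny standalone sketch:

    import torch

    w = torch.tensor(0.0)                  # a weight that would normally kill the gradient
    p = torch.tensor(3.0, requires_grad=True)

    y = (w * p).detach() + p - p.detach()  # forward value: w * p = 0.0
    y.backward()
    print(y.item(), p.grad.item())         # 0.0 1.0 -- the gradient passes through unchanged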

    I don't think that can be improved on. The other pixel values have zero gradient because they simply do not participate in the computation. You cannot compute a gradient from the result to pixel_{i} if pixel_{i} does not participate in the computation of the output.

    If you want there to be a gradient at all pixels, then all pixels need to participate in the computation of the output.
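
    If you only need the warp itself to stay differentiable, the built-in affine_grid/grid_sample pair (which the question mentions trying) is the idiomatic route. It has the same sparsity property, since it also samples bilinearly, but it is additionally differentiable with respect to the transform. The usual pitfall is that affine_grid works in normalized [-1, 1] coordinates, so a cv2-style pixel-space matrix has to be conjugated by the normalization transforms first. A sketch under that assumption (the function name is hypothetical, and it assumes a float32 (H, W, C) image):

    import torch
    import torch.nn.functional as F

    def tensor_warp_affine_grid(tensor, M, output_shape):
        # tensor: (H, W, C) float32; M: cv2-style 2x3 matrix mapping source -> destination pixels
        H_in, W_in, C = tensor.shape
        H_out, W_out = output_shape

        M3 = torch.eye(3)
        M3[:2] = torch.as_tensor(M, dtype=torch.float32)
        M_inv = torch.inverse(M3)  # destination pixel -> source pixel

        def norm(h, w):
            # pixel coordinates -> normalized [-1, 1] coordinates (align_corners=True convention)
            return torch.tensor([[2.0 / (w - 1), 0.0, -1.0],
                                 [0.0, 2.0 / (h - 1), -1.0],
                                 [0.0, 0.0, 1.0]])

        # Conjugate the pixel-space map by the normalization transforms
        theta = (norm(H_in, W_in) @ M_inv @ torch.inverse(norm(H_out, W_out)))[:2]

        grid = F.affine_grid(theta[None], [1, C, H_out, W_out], align_corners=True)
        out = F.grid_sample(tensor.permute(2, 0, 1)[None], grid,
                            mode='bilinear', padding_mode='zeros', align_corners=True)
        return out[0].permute(1, 2, 0)  # back to (H, W, C)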