Tags: python, tensorflow, pytorch, padding, conv1d

How does padding=zeros work in PyTorch's functional.conv1d?


The code below gives an output of shape (1,1,3) when xodd has shape (1,1,2). The kernel shape is (112, 1, 1).

from torch.nn import functional as F
output = F.conv1d(xodd, kernel, padding=zeros)

How does padding=zeros work?
Also, how can I write equivalent code in TensorFlow so that its output is the same as the output above?


Solution

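  • What is padding=zeros? With padding=0, nothing is added to the left or right of the tensor, so the kernel only visits positions where it fits entirely inside the input. The padding argument of F.conv1d accepts an int, a one-element tuple, or the strings 'valid'/'same', and whatever padding it adds is filled with zeros; 'zeros' itself is one of the padding_mode values of nn.Conv1d, not a padding value.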

    Padding=0:

    from torch.nn import functional as F
    import torch
    inputs = torch.randn(33, 16, 6) # (minibatch,in_channels,features)
    filters = torch.randn(20, 16, 5) # (out_channels, in_channels, kernel_size)
    out_tns = F.conv1d(inputs, filters, stride=1, padding=0)
    print(out_tns.shape)
    # torch.Size([33, 20, 2]) # (minibatch,out_channels,(features-kernel_size+1))
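The output lengths in these examples follow the output-size formula from the PyTorch nn.Conv1d documentation: L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. A minimal sketch that checks the formula against F.conv1d; conv1d_out_len is a hypothetical helper written only for this illustration, not a PyTorch function.

```python
import torch
from torch.nn import functional as F


def conv1d_out_len(l_in, kernel_size, padding=0, stride=1, dilation=1):
    # Output-length formula from the PyTorch Conv1d docs (integer floor division)
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


inputs = torch.randn(33, 16, 6)    # (minibatch, in_channels, features)
filters = torch.randn(20, 16, 5)   # (out_channels, in_channels, kernel_size)

for pad in (0, 1, 2):
    out = F.conv1d(inputs, filters, stride=1, padding=pad)
    # Actual length from PyTorch vs. length predicted by the formula
    print(pad, out.shape[-1], conv1d_out_len(6, 5, padding=pad))
# 0 2 2
# 1 4 4
# 2 6 6
```

Each extra unit of padding adds one zero on each side, so the output grows by two positions per unit.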
    

    Padding=2 (two zeros are added to the left and two to the right of the tensor):

    inputs = torch.randn(33, 16, 6) # (minibatch,in_channels,features)
    filters = torch.randn(20, 16, 5) # (out_channels, in_channels, kernel_size)
    out_tns = F.conv1d(inputs, filters, stride=1, padding=2)
    print(out_tns.shape)
    # torch.Size([33, 20, 6]) # (minibatch,out_channels,(features-kernel_size+1+2+2))
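To check that padding=2 really means "two zeros on each side", you can pad the input yourself with F.pad and then convolve with padding=0. A minimal sketch, assuming the same kinds of random tensors as above; the two outputs agree up to floating-point rounding.

```python
import torch
from torch.nn import functional as F

inputs = torch.randn(33, 16, 6)    # (minibatch, in_channels, features)
filters = torch.randn(20, 16, 5)   # (out_channels, in_channels, kernel_size)

# Built-in padding: two implicit zeros on each side of the last dimension
out_builtin = F.conv1d(inputs, filters, stride=1, padding=2)

# Manual padding: F.pad with (left, right) = (2, 2) on the last dimension,
# using the default mode="constant" (fill value 0), then no implicit padding
padded = F.pad(inputs, (2, 2))
out_manual = F.conv1d(padded, filters, stride=1, padding=0)

print(padded.shape)                                          # torch.Size([33, 16, 10])
print(torch.allclose(out_builtin, out_manual, atol=1e-4))    # True
```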
    

    Equivalent code in TensorFlow (padding=0 corresponds to the Keras default padding="valid"; note that TensorFlow expects channels-last input):

    import tensorflow as tf
    input_shape = (33, 6, 16) # (minibatch, features, in_channels): channels-last
    x = tf.random.normal(input_shape)
    out_tf = tf.keras.layers.Conv1D(filters=20,
                                    kernel_size=5,
                                    strides=1,
                                    padding="valid")(x) # "valid" = no padding, like padding=0
    print(out_tf.shape)
    # (33, 2, 20) # (minibatch, (features-kernel_size+1), out_channels)
    
    # To get exactly the same layout as PyTorch (channels first), transpose the result
    tf.transpose(out_tf, [0, 2, 1]).shape
    # TensorShape([33, 20, 2])