tensorflow, pytorch, onnx, vitis-ai

Tensor format issue when converting PyTorch -> ONNX -> TensorFlow


I have an issue with a TensorFlow model that was converted from PyTorch -> ONNX -> TensorFlow. The converted TensorFlow model expects its input in the PyTorch layout, i.e. (batch size, number of channels, height, width), rather than in the TensorFlow layout (batch size, height, width, number of channels). Because of this, I cannot process the model further with Vitis AI.

So I would like to ask: is there any way to convert this PyTorch input format to the TensorFlow format using tools from ONNX, TensorFlow 1, or anything else?
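
(For reference, the two layouts differ only in axis order; the toy NumPy snippet below, with a made-up array, just illustrates the permutation I mean.)

import numpy as np

# Illustration only: a TensorFlow-style NHWC array vs. the NCHW layout
# the converted graph expects.
nhwc = np.zeros((1, 1080, 1920, 3))      # (batch, height, width, channels)
nchw = np.transpose(nhwc, (0, 3, 1, 2))  # (batch, channels, height, width)
print(nchw.shape)                        # (1, 3, 1080, 1920)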

My code is as follows:

PyTorch -> ONNX

from hardnet import hardnet
import torch
import onnx

# Load the checkpoint and restore the trained weights.
ckpt = torch.load('../hardnet.pth')
model_state_dict = ckpt['model_state_dict']
optimizer_state_dict = ckpt['optimizer_state_dict']

model = hardnet(11)
model.load_state_dict(model_state_dict)
model.eval()

# Export to ONNX with a dummy input in PyTorch's NCHW layout.
dummy_input = torch.randn(1, 3, 1080, 1920)
input_names = ['input0']
output_names = ['output0']

output_file = 'hardnet.onnx'
torch.onnx.export(model, dummy_input, output_file, verbose=True,
    input_names=input_names, output_names=output_names,
    opset_version=11, keep_initializers_as_inputs=True)

# Sanity-check the exported model.
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
print('Passed Onnx')

ONNX -> TensorFlow 1 (using TensorFlow 1.15)

import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import onnx
from onnx_tf.backend import prepare

output_file = 'hardnet.onnx'
onnx_model = onnx.load(output_file)

# Convert the ONNX model to a TensorFlow graph and serialize it.
tf_rep = prepare(onnx_model)
tf_rep.export_graph('hardnet.pb')
tf.compat.v1.disable_eager_execution()

def load_pb(path_to_pb: str):
    """From: https://stackoverflow.com/questions/51278213/what-is-the-use-of-a-pb-file-in-tensorflow-and-how-does-it-work
    """
    with tf.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


graph = load_pb('hardnet.pb')
input_tensor = graph.get_tensor_by_name('input0:0')
output_tensor = graph.get_tensor_by_name('output0:0')

# ImageNet normalization.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img = cv2.imread('train_0.jpg', cv2.IMREAD_COLOR)
img = cv2.resize(img, (1920, 1080))

img = img / 255
img = img - mean
img = img / std

# To PyTorch format: (H, W, C) -> (H, W, C, 1) -> (1, C, H, W).
img = np.expand_dims(img, -1)
img = np.transpose(img, (3, 2, 0, 1))
img = img.astype(np.float32)  # the exported graph expects float32

with tf.Session(graph=graph) as sess:
    pred = sess.run(output_tensor, {input_tensor: img})

Solution

  • You could wrap your PyTorch model in another module that performs the transpose you want to have in TensorFlow. See the following example:

    Let's say you have the following toy NN:

    import torch
    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.rnn = nn.LSTM(10, 20, 2)
    
        def forward(self, x):
            h0 = torch.zeros(2, 3, 20)
            c0 = torch.zeros(2, 3, 20)
            return self.rnn(x, (h0, c0))
    

    the example PyTorch/TensorFlow input shapes would be:

    >> pytorch_input  = torch.randn(5, 3, 10)
    >> tf_input  = torch.transpose(pytorch_input, 1, 2)
    
    >> print("PyTorch input shape: ", pytorch_input.shape)
    >> print("TensorFlow input shape: ", tf_input.shape)
    
    PyTorch input shape:  torch.Size([5, 3, 10])
    TensorFlow input shape:  torch.Size([5, 10, 3])
    

    Now, the wrapper, which first transposes the input and then passes it on to the wrapped model:

    class NetTensorFlowWrapper(nn.Module):
        def __init__(self, main_module: nn.Module):
            super(NetTensorFlowWrapper, self).__init__()
            self.main_module = main_module
            
        def forward(self, x):
            x = torch.transpose(x, 1, 2)
            return self.main_module(x)
    

    Then, this is possible:

    net = Net()
    net_wrapper = NetTensorFlowWrapper(net)
    
    net(pytorch_input)
    net_wrapper(tf_input)
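
    Since the wrapper only permutes the axes before calling the same underlying module, both calls should produce identical results. A quick check (just a sketch; the variable names are mine, and the LSTM returns an (output, state) tuple, hence the unpacking):

    out_direct, _ = net(pytorch_input)
    out_wrapped, _ = net_wrapper(tf_input)
    print(torch.allclose(out_direct, out_wrapped))  # expected: True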
    

    Then, when you finally save both models as before via torch.onnx.export and read their graphs with the onnx package (not torch.onnx), you will see that the plain model's graph consumes the PyTorch-shaped input directly, while the wrapped model's graph starts with a Transpose node (a sketch of how to print these graphs follows after the dumps):

    graph torch-jit-export (
      %input0[FLOAT, 5x3x10]
    ) {
      %76 = Shape(%input0)
      %77 = Constant[value = <Scalar Tensor []>]()
      ...

    graph torch-jit-export (
      %input0[FLOAT, 5x10x3]
    ) {
      %9 = Transpose[perm = [0, 2, 1]](%input0)
      %77 = Shape(%9)
      %78 = Constant[value = <Scalar Tensor []>]()
      ...
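
    For reference, the dumps above can be produced roughly like this (a sketch reusing net, net_wrapper, pytorch_input and tf_input from above; the output file names are just examples):

    import torch
    import onnx

    # Export both the plain and the wrapped model to ONNX.
    torch.onnx.export(net, pytorch_input, 'net.onnx',
        input_names=['input0'], opset_version=11)
    torch.onnx.export(net_wrapper, tf_input, 'net_wrapper.onnx',
        input_names=['input0'], opset_version=11)

    # Print each graph in the human-readable form shown above.
    for path in ('net.onnx', 'net_wrapper.onnx'):
        model = onnx.load(path)
        print(onnx.helper.printable_graph(model.graph))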