
What's the workaround for "ragged/jagged tensors" in PyTorch?


TensorFlow provides ragged tensors (https://www.tensorflow.org/guide/ragged_tensor). PyTorch, however, doesn't provide such a data structure. Is there a workaround to construct something similar in PyTorch?

import numpy as np
x = np.array([[0], [0, 1]], dtype=object)  # without dtype=object, NumPy >= 1.24 raises ValueError
print(x)  # [list([0]) list([0, 1])]

import tensorflow as tf
x = tf.ragged.constant([[0], [0, 1]])
print(x)  # <tf.RaggedTensor [[0], [0, 1]]>

import torch
# x = torch.Tensor([[0], [0, 1]])  # ValueError

Solution

  • PyTorch is implementing something called NestedTensors, which serve much the same purpose as RaggedTensors in TensorFlow. You can follow the RFC and progress here.
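A minimal sketch of what this looks like, assuming a PyTorch release that ships the `torch.nested` module (it appeared as a prototype in 1.13, so details may change between versions). For comparison it also shows the long-standing workaround of padding the rows and keeping a boolean mask, which uses only stable APIs (`torch.nn.utils.rnn.pad_sequence`). The fill value `-1.0` and the variable names are illustrative choices, not anything prescribed by PyTorch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# The same ragged data as in the question, as float tensors of different lengths.
rows = [torch.tensor([0.0]), torch.tensor([0.0, 1.0])]

# Option 1: a nested tensor (torch.nested is a prototype API, PyTorch >= 1.13).
nt = torch.nested.nested_tensor(rows)
components = [t.tolist() for t in nt.unbind()]  # [[0.0], [0.0, 1.0]]

# Dense copy of the nested tensor: short rows are filled with `padding`.
padded_nested = torch.nested.to_padded_tensor(nt, padding=-1.0)

# Option 2: pad to the longest row and keep a mask marking the real entries.
padded = pad_sequence(rows, batch_first=True, padding_value=-1.0)
lengths = torch.tensor([len(r) for r in rows])
mask = torch.arange(padded.size(1)) < lengths.unsqueeze(1)  # True where data is real

print(padded_nested)
print(padded)
print(mask)
```

Both options produce the same dense 2x2 tensor here; the nested tensor keeps the true row lengths internally, while the padding approach needs the separate `mask` (or `lengths`) so that later computations can ignore the filler values.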