
How to use `stack()` in PyTorch?


How do I use `torch.stack()` to stack two tensors with shapes `a.shape = (2, 3, 4)` and `b.shape = (2, 3)` without an in-place operation?


Solution

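  • `torch.stack()` requires every tensor to have exactly the same shape, because it joins them along a new dimension. What you want here is concatenation along the last dimension with `torch.cat()`, which only requires the same number of dimensions and matching sizes in every dimension except the one you concatenate along. So first unsqueeze `b` to add a trailing dimension of size 1, then concatenate. Neither step is in-place: `torch.unsqueeze()` returns a new tensor (a view) and leaves the original untouched; the in-place variant would be `b.unsqueeze_()`. For example:

    import torch

    a = torch.randn(2, 3, 4)
    b = torch.randn(2, 3)
    b = torch.unsqueeze(b, dim=2)  # 2, 3, 1
    # torch.unsqueeze(b, dim=-1) does the same thing
    
    torch.cat([a, b], dim=2)  # 2, 3, 5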
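Here is a quick check of the points above. It assumes random input tensors with the shapes from the question, and the names `b3`, `stack_failed` and `c` are purely illustrative. It confirms three things: `torch.stack()` rejects the mismatched shapes, `torch.cat()` produces a `(2, 3, 5)` result, and the original `b` is left unchanged. It also shows what `stack` does when the shapes do match.

```python
import torch

a = torch.randn(2, 3, 4)
b = torch.randn(2, 3)

# Out-of-place: unsqueeze returns a new (view) tensor; b keeps shape (2, 3).
b3 = b.unsqueeze(-1)  # shape (2, 3, 1)

# torch.stack needs identical shapes, so (2, 3, 4) vs (2, 3, 1) raises an error.
try:
    torch.stack([a, b3], dim=2)
    stack_failed = False
except RuntimeError:
    stack_failed = True

# torch.cat only needs matching sizes outside the concatenation dim.
c = torch.cat([a, b3], dim=2)

# With identical shapes, stack inserts a new dimension at position 2.
same = torch.stack([a, a], dim=2)

print(stack_failed)        # True
print(tuple(b.shape))      # (2, 3)
print(tuple(c.shape))      # (2, 3, 5)
print(tuple(same.shape))   # (2, 3, 2, 4)
```

The last slice along the new dimension of `c` is exactly the old `b`, so `torch.equal(c[..., 4], b)` holds.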