
PyTorch preferred way to copy a tensor


There seem to be several ways to create a copy of a tensor in PyTorch, including

y = tensor.new_tensor(x) #a

y = x.clone().detach() #b

y = torch.empty_like(x).copy_(x) #c

y = torch.tensor(x) #d

b is explicitly preferred over a and d, according to a UserWarning I get when I execute either a or d. Why is it preferred? Performance? I'd argue it's less readable.
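
For reference, a minimal way to surface that warning (the exact message text varies by PyTorch version) looks like this:

import warnings
import torch

x = torch.arange(3.0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = torch.tensor(x)       # d: triggers the copy-construct UserWarning
    z = x.new_tensor(x)       # a: triggers the same warning
    w = x.clone().detach()    # b: no warning

print([str(c.message) for c in caught])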

Any reasons for/against using c?


Solution

  • TL;DR

    Use .clone().detach() (or preferably .detach().clone())

    If you first detach the tensor and then clone it, the computation path is not copied, the other way around it is copied and then abandoned. Thus, .detach().clone() is very slightly more efficient. -- PyTorch forums

    as it's slightly faster and explicit in what it does.
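
    A minimal sketch of that autograd difference, assuming a tensor that requires grad (the tensor names here are just illustrative):

    import torch

    x = torch.randn(3, requires_grad=True)

    y = x.detach().clone()   # detach first: the copy carries no autograd history
    print(y.requires_grad)   # False

    z = x.clone()            # clone alone stays in the graph and still tracks x
    print(z.requires_grad)   # True

    # Editing the detached copy does not affect x or its gradients.
    y[0] = 100.0
    x.sum().backward()
    print(x.grad)            # tensor([1., 1., 1.])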


    Using perfplot, I plotted the timing of various methods to copy a pytorch tensor.

    y = tensor.new_tensor(x) # method a
    
    y = x.clone().detach() # method b
    
    y = torch.empty_like(x).copy_(x) # method c
    
    y = torch.tensor(x) # method d
    
    y = x.detach().clone() # method e
    

    The x-axis is the size of the tensor being copied and the y-axis is the time taken; both axes are on a linear scale. As you can see, tensor() and new_tensor() take more time than the other three methods.

    [perfplot graph: timing comparison for copying a PyTorch tensor]

    Note: In multiple runs, I noticed that any of b, c, and e can have the lowest time; the same is true of a and d. However, b, c, and e consistently have lower timings than a and d.

    import torch
    import perfplot
    
    perfplot.show(
        setup=lambda n: torch.randn(n),  # a random tensor of length n
        kernels=[
            lambda a: a.new_tensor(a),                # method a
            lambda a: a.clone().detach(),             # method b
            lambda a: torch.empty_like(a).copy_(a),   # method c
            lambda a: torch.tensor(a),                # method d
            lambda a: a.detach().clone(),             # method e
        ],
        labels=["new_tensor()", "clone().detach()", "empty_like().copy()", "tensor()", "detach().clone()"],
        n_range=[2 ** k for k in range(15)],  # tensor lengths 1 .. 16384
        xlabel="len(a)",
        logx=False,
        logy=False,
        title='Timing comparison for copying a pytorch tensor',
    )