Tags: python, pytorch, tensor

PyTorch: how to (efficiently) apply a function without a “dim” argument to each row of a 2D tensor?


Long story short, I have a 2D matrix of ones and zeros and I need to retrieve, for each row, the indices of the elements set to one. The “standard” way to do so would be torch.nonzero, but that function is well known for 1) being a real bottleneck, since it does not know the size of the output in advance, and 2) not being applicable to each row of a 2D tensor in one shot, since different rows may contain different numbers of ones.
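
For illustration (my own toy example, not from the original question), this is what makes the output "ragged":

    import torch

    x = torch.tensor([[1, 0, 1],
                      [0, 0, 1]])
    print(x.nonzero())
    # tensor([[0, 0],
    #         [0, 2],
    #         [1, 2]])
    # row 0 contributes two indices, row 1 only one, so the per-row
    # results have different lengths and cannot be stacked into a tensor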

Recently, at::nonzero_static was introduced, which solves the first point by letting you pass the expected maximum number of nonzero elements (which is fine for my application). However, it does not feature a “dim” argument, meaning it cannot be applied to each row/column individually. In my opinion this makes no sense, since fixing the size of the output guarantees that every row yields the same number of items, which would make the output a regular tensor.
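
As a quick sketch of how nonzero_static behaves on a single row (it is a fairly recent addition, so older PyTorch builds may lack it or support it on CPU only):

    import torch

    row = torch.tensor([1, 0, 1, 1, 0])
    # request exactly 4 output slots; unused slots are padded with
    # fill_value, which defaults to -1
    print(row.nonzero_static(size=4))
    # tensor([[ 0],
    #         [ 2],
    #         [ 3],
    #         [-1]])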

Using a for loop would obviously solve my issue, but that would mean calling the function several times, which is not GPU-efficient. Does anyone know a way to apply nonzero_static efficiently to each row, returning a tensor where each row holds the result for the corresponding slice? From my understanding, vmap may be a solution, but I am not sure whether it is optimized for the GPU.
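
For reference, the vmap route mentioned above would look roughly like the sketch below; the solution benchmarks exactly this pattern:

    import torch

    data = (torch.rand(4, 6) > 0.5).long()   # toy 0/1 matrix
    per_row = torch.func.vmap(torch.nonzero_static)
    idx = per_row(data, size=6).squeeze(-1)  # (4, 6, 1) -> (4, 6)
    # rows with fewer than `size` ones are padded with -1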


Solution

  • I implemented a few solutions; a standalone sketch of the fastest one follows the timing results below. Some preliminaries:

    import torch
    import time
    m = 2000
    n = 1000
    trials = 100
    
    results = {}
    for t in range(trials):
        
        device = torch.device("cpu")
        # random 0/1 matrix of shape [m, n]
        data = torch.rand([m,n],device = device).round().long()
        
        # use nonzero, put in matrix form, then sort
        name = "nonzero"
        t1 = time.time()
        idx = data.nonzero()
        midx = idx[:,0]
        nidx = idx[:,1]
        output = torch.zeros([m,n],device = device,dtype = torch.long)
        output[midx,nidx] = nidx 
        output = output.sort(dim = 1,descending = True).values
        torch.cuda.synchronize()  # effectively a no-op for the CPU runs, kept for symmetry
        results[name] = results.get(name,0.0) + time.time() - t1
        
        
        # use nonzero_static and leave in "listy" form
        name = "nonzero_static"
        t1 = time.time()
        count_nonzero = int(data.sum().item())  # exact number of ones, i.e. the output size
        d = data.view(-1)
        idx = d.nonzero_static(size = count_nonzero)
        midx,nidx = idx//n, idx%n
        torch.cuda.synchronize()
        results[name] = results.get(name,0.0) + time.time() - t1
        
        # use nonzero_static and put in matrix form, leave unsorted
        name = "nonzero_static -> matrix"
        t1 = time.time()
        count_nonzero = int(data.sum().item())
        d = data.view(-1)
        idx = d.nonzero_static(size = count_nonzero)
        midx,nidx = idx//n, idx%n
        output = torch.zeros([m,n],device = device,dtype = torch.long)
        output[midx,nidx] = nidx 
        torch.cuda.synchronize()
        results[name] = results.get(name,0.0) + time.time() - t1
        
        
        
        # use nonzero_static and put in matrix form, then sort
        name = "nonzero_static -> sorted matrix"
        t1 = time.time()
        count_nonzero = int(data.sum().item())
        d = data.view(-1)
        idx = d.nonzero_static(size = count_nonzero)
        midx,nidx = idx//n, idx%n
        output = torch.zeros([m,n],device = device,dtype = torch.long)
        output[midx,nidx] = nidx 
        output = output.sort(dim = 1,descending = True).values
        torch.cuda.synchronize()
        results[name] = results.get(name,0.0) + time.time() - t1
        
        
        # vmap nonzero_static
        name = "vmap nonzero_static"
        t1 = time.time()
        test = torch.func.vmap(torch.nonzero_static)  # note: the wrapper is (re)built inside the timed region
        output = test(data,size = n).squeeze(-1)      # (m, n, 1) -> (m, n)
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        results[name] = results.get(name,0.0) + time.time() - t1
        
        # use index broadcasting then sort
        name = "index broadcasting"
        t1 = time.time()
        index_array = torch.arange(n).unsqueeze(0).expand(m,n)
        output = data*index_array   # each one is replaced by its column index
        output = output.sort(dim = 1,descending = True).values
        torch.cuda.synchronize()
        results[name] = results.get(name,0.0) + time.time() - t1
        
        
        
        device = torch.device("cuda:0")
        data = data.to(device)
        torch.cuda.synchronize()
        
        # use index broadcasting then sort on GPU
        name = "GPU index broadcasting"
        t1 = time.time()
        index_array = torch.arange(n,device = device).unsqueeze(0).expand(m,n)
        output = data*index_array
        output = output.sort(dim = 1,descending = True).values
        torch.cuda.synchronize()
        results[name] = results.get(name,0.0) + time.time() - t1
        
        del output
        torch.cuda.empty_cache()
        
        # use nonzero, put in matrix form, then sort (GPU)
        name = "GPU nonzero"
        t1 = time.time()
        idx = data.nonzero()
        midx = idx[:,0]
        nidx = idx[:,1]
        output = torch.zeros([m,n],device = device,dtype = torch.long)
        output[midx,nidx] = nidx 
        output = output.sort(dim = 1,descending = True).values
        torch.cuda.synchronize()
        results[name] = results.get(name,0.0) + time.time() - t1
            
    print("Results for [{},{}] over {} trials".format(m,n,trials))
    for key in results:
        print("{:.5f}s for {}".format(results[key]/trials,key))
       
    
    
    Running the same script with different values of m, n, and trials gives:

    Results for [200,100] over 100 trials
    0.00051s for nonzero
    0.00035s for nonzero_static
    0.00037s for nonzero_static -> matrix
    0.00062s for nonzero_static -> sorted matrix
    0.00191s for vmap nonzero_static
    0.00033s for index broadcasting
    0.00015s for GPU index broadcasting
    0.00019s for GPU nonzero
    
    Results for [2000,1000] over 100 trials
    0.00575s for nonzero
    0.01028s for nonzero_static
    0.01036s for nonzero_static -> matrix
    0.01302s for nonzero_static -> sorted matrix
    0.03645s for vmap nonzero_static
    0.00466s for index broadcasting
    0.00129s for GPU index broadcasting
    0.00198s for GPU nonzero
    
    Results for [20000,10000] over 20 trials
    0.67861s for nonzero
    1.10534s for nonzero_static
    1.31800s for nonzero_static -> matrix
    1.66106s for nonzero_static -> sorted matrix
    2.68011s for vmap nonzero_static
    0.55859s for index broadcasting
    0.31346s for GPU index broadcasting
    0.30350s for GPU nonzero
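
  • Takeaway: the GPU variants are fastest across the board, with index broadcasting winning at small and medium sizes and GPU nonzero narrowly ahead at the largest one. As promised above, here is the index-broadcasting idea factored into a standalone helper. This is my own sketch, not part of the original benchmark: it uses 1-based column indices internally so that a genuine one in column 0 is not confused with the zero padding (a corner case the benchmark code glosses over), and marks padding with -1.

    import torch

    def row_nonzero_desc(data):
        # hypothetical helper (my naming): for each row of a 0/1 matrix,
        # return the column indices of the ones, sorted in descending
        # order and padded with -1
        m, n = data.shape
        # 1-based indices keep a one in column 0 distinguishable from the
        # zeros; subtracting 1 afterwards turns the padding into -1
        cols = torch.arange(1, n + 1, device=data.device).expand(m, n)
        return (data * cols).sort(dim=1, descending=True).values - 1

    x = torch.tensor([[1, 0, 1],
                      [0, 0, 1]])
    print(row_nonzero_desc(x))
    # tensor([[ 2,  0, -1],
    #         [ 2, -1, -1]])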