Long story short, I have a 2D matrix of ones and zeros and I need to retrieve, for each row, the indices of the elements set to one. The "standard" way to do so would be torch.nonzero, but that function is well known for 1) being a real bottleneck, since it does not know the size of the output in advance, and 2) not being applicable to each row of a 2D tensor in one shot, since different rows may contain different numbers of ones.
Recently, at::nonzero_static was introduced, which solves the first point by letting you pass the expected maximum number of nonzero elements (which is fine for my application). However, it does not feature a "dim" argument, so it cannot be applied to each row/column individually. In my opinion that makes little sense: fixing the size of the output guarantees that every row yields the same number of items, so the result could naturally be a regular tensor.
Using a for loop would obviously solve my issue, but that means calling the function many times, which is not GPU efficient. Does anyone know a way to apply nonzero_static to each row efficiently and return a tensor in which each row is the result of applying it to the corresponding slice? From my understanding, vmap may be a solution, but I am not sure whether it is optimized for GPU.
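For concreteness, here is a minimal sketch of the per-row loop I would like to avoid (the example matrix, the max_ones bound and the -1 padding are just illustrative):

import torch

mask = torch.tensor([[1, 0, 1, 0],
                     [0, 0, 1, 1],
                     [1, 1, 1, 0]])
max_ones = 3  # known upper bound on the number of ones per row

# one nonzero_static call per row: correct, but serial and not GPU friendly
rows = [mask[i].nonzero_static(size=max_ones, fill_value=-1).squeeze(-1)
        for i in range(mask.shape[0])]
out = torch.stack(rows)
# tensor([[ 0,  2, -1],
#         [ 2,  3, -1],
#         [ 0,  1,  2]])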
I implemented a few solutions. Some preliminaries:
- nonzero_static() is unfortunately not compatible with the CUDA backend, which may be limiting for your use case.
- vmap will likely not work, as it "does not provide general autobatching or handle variable-length sequences out of the box" and creates a BatchedTensor output. Running vmap on nonzero_static produces the warning UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::nonzero_static.
- nonzero() was the fastest, or nearly as fast as, the index-broadcasting solution (a standalone sketch of that trick is given after the results below). It seems that not knowing the allocation size in advance is in general not a large bottleneck compared to the relatively clunky workarounds. It would be interesting to re-evaluate if either nonzero_static were optimized for batched computation with vmap or a CUDA backend were implemented for nonzero_static, which will hopefully happen eventually, as it is a relatively new function in PyTorch.

Benchmark script:

import torch
import time
m = 2000
n = 1000
trials = 100
results = {}
for t in range(trials):
device = torch.device("cpu")
data = torch.rand([m,n],device = device).round().long()
# use nonzero
name = "nonzero"
t1 = time.time()
idx = data.nonzero()
midx = idx[:,0]
nidx = idx[:,1]
output = torch.zeros([m,n],device = device,dtype = torch.long)
output[midx,nidx] = nidx
output = output.sort(dim = 1,descending = True)
torch.cuda.synchronize()
try:
results[name] += time.time()- t1
except:
results[name] = time.time() - t1
# use nonzero_static and leave in "listy" form
name = "nonzero_static"
t1 = time.time()
count_nonzero = int(data.sum().item())
d = data.view(-1)
idx = d.nonzero_static(size = count_nonzero)
midx,nidx = idx//n, idx%n
torch.cuda.synchronize()
try:
results[name] += time.time()- t1
except:
results[name] = time.time() - t1
# use nonzero_static and put in matrix form, leave unsorted
name = "nonzero_static -> matrix"
t1 = time.time()
count_nonzero = int(data.sum().item())
d = data.view(-1)
idx = d.nonzero_static(size = count_nonzero)
midx,nidx = idx//n, idx%n
output = torch.zeros([m,n],device = device,dtype = torch.long)
output[midx,nidx] = nidx
torch.cuda.synchronize()
try:
results[name] += time.time()- t1
except:
results[name] = time.time() - t1
# use nonzero_static and put in matrix form, then sort
name = "nonzero_static -> sorted matrix"
t1 = time.time()
count_nonzero = int(data.sum().item())
d = data.view(-1)
idx = d.nonzero_static(size = count_nonzero)
midx,nidx = idx//n, idx%n
output = torch.zeros([m,n],device = device,dtype = torch.long)
output[midx,nidx] = nidx
output = output.sort(dim = 1,descending = True)
torch.cuda.synchronize()
try:
results[name] += time.time()- t1
except:
results[name] = time.time() - t1
# vmap nonzero_static
name = "vmap nonzero_static"
t1 = time.time()
test = torch.func.vmap(torch.nonzero_static)
output = test(data,size = n).squeeze(-1)
torch.cuda.synchronize()
torch.cuda.empty_cache()
try:
results[name] += time.time()- t1
except:
results[name] = time.time() - t1
# use index broadcasting then sort
name = "index broadcasting"
t1 = time.time()
index_array = torch.arange(n).unsqueeze(0).expand(m,n)
output = data*index_array
output = output.sort(dim = 1,descending = True)
torch.cuda.synchronize()
try:
results[name] += time.time()- t1
except:
results[name] = time.time() - t1
device = torch.device("cuda:0")
data = data.to(device)
torch.cuda.synchronize()
# use index broadcasting then sort on GPU
name = "GPU index broadcasting"
t1 = time.time()
index_array = torch.arange(n,device = device).unsqueeze(0).expand(m,n)
output = data*index_array
output = output.sort(dim = 1,descending = True)
torch.cuda.synchronize()
try:
results[name] += time.time()- t1
except:
results[name] = time.time() - t1
del output
torch.cuda.empty_cache()
#use nonzero and leave in listy form
name = "GPU nonzero"
t1 = time.time()
idx = data.nonzero()
midx = idx[:,0]
nidx = idx[:,1]
output = torch.zeros([m,n],device = device,dtype = torch.long)
output[midx,nidx] = nidx
output = output.sort(dim = 1,descending = True)
torch.cuda.synchronize()
try:
results[name] += time.time()- t1
except:
results[name] = time.time() - t1
print("Results for [{},{}] over {} trials".format(m,n,trials))
for key in results:
print("{:.5f}s for {}".format(results[key]/trials,key))
Results for [200,100] over 100 trials
0.00051s for nonzero
0.00035s for nonzero_static
0.00037s for nonzero_static -> matrix
0.00062s for nonzero_static -> sorted matrix
0.00191s for vmap nonzero_static
0.00033s for index broadcasting
0.00015s for GPU index broadcasting
0.00019s for GPU nonzero
Results for [2000,1000] over 100 trials
0.00575s for nonzero
0.01028s for nonzero_static
0.01036s for nonzero_static -> matrix
0.01302s for nonzero_static -> sorted matrix
0.03645s for vmap nonzero_static
0.00466s for index broadcasting
0.00129s for GPU index broadcasting
0.00198s for GPU nonzero
Results for [20000,10000] over 20 trials
0.67861s for nonzero
1.10534s for nonzero_static
1.31800s for nonzero_static -> matrix
1.66106s for nonzero_static -> sorted matrix
2.68011s for vmap nonzero_static
0.55859s for index broadcasting
0.31346s for GPU index broadcasting
0.30350s for GPU nonzero
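For reference, here is the index-broadcasting trick pulled out of the benchmark as a standalone helper (the function name is mine, not part of any API). Note that with this padding scheme a one in column 0 produces a 0 that is indistinguishable from the padding zeros; that did not matter for the timing comparison, but it may matter for your application:

import torch

def row_nonzero_indices(mask):
    # mask: [m, n] tensor of zeros and ones
    # returns an [m, n] tensor whose rows hold the column indices of the ones,
    # sorted in descending order, with the remaining slots filled by zeros
    m, n = mask.shape
    index_array = torch.arange(n, device=mask.device).unsqueeze(0).expand(m, n)
    return (mask * index_array).sort(dim=1, descending=True).values

device = "cuda:0" if torch.cuda.is_available() else "cpu"
mask = torch.tensor([[1, 0, 1, 0],
                     [0, 0, 1, 1]], device=device)
print(row_nonzero_indices(mask))
# tensor([[2, 0, 0, 0],
#         [3, 2, 0, 0]])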