
How to perform operations on very big torch tensors without splitting them

My Task:

I'm trying to calculate the pairwise distance between every pair of samples in two big tensors (for k-Nearest-Neighbours). That is, given a tensor test with shape (b1,c,h,w) and a tensor train with shape (b2,c,h,w), I need || test[i]-train[j] || for every i,j (where both test[i] and train[j] have shape (c,h,w), as those are samples in the batch).
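
For tensors small enough to fit in memory, the quantity described above can be sketched as follows. The helper name `pairwise_l2` and the toy shapes are illustrative assumptions, not part of the original setup.

```python
import torch

def pairwise_l2(test: torch.Tensor, train: torch.Tensor) -> torch.Tensor:
    """Return a (b1, b2) matrix whose [i, j] entry is ||test[i] - train[j]||_2."""
    # flatten(1) keeps the batch axis and merges (c, h, w) into a single axis
    return torch.cdist(test.flatten(1), train.flatten(1), p=2)

torch.manual_seed(0)
test = torch.rand(3, 2, 4, 4)    # b1 = 3 samples of shape (c, h, w) = (2, 4, 4)
train = torch.rand(5, 2, 4, 4)   # b2 = 5 samples of the same shape
distances = pairwise_l2(test, train)
print(distances.shape)           # torch.Size([3, 5])
```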

The Problem

Both train and test are very big, so I can't fit them into RAM.

My current solution

For a start, I did not construct these tensors in one go. As I build them, I split the data tensor into chunks and save them separately to disk, so I end up with files {Test\test_1,...,Test\test_n} and {Train\train_1,...,Train\train_m}. Then, in a nested for loop, I load every Test\test_i and Train\train_j, calculate the distances for that pair of chunks, and save them.

This semi-pseudo-code might explain it:

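import torch

# raw f-strings, so that the "\t" in the paths is not read as a tab character
test_files = [rf'Test\test_{i}' for i in range(n)]
train_files = [rf'Train\train_{j}' for j in range(m)]
dist = lambda t1, t2: torch.cdist(t1.flatten(1), t2.flatten(1))
all_distances = []
for test_file in test_files:
    test_i = torch.load(test_file)  # one chunk of test, shape (b,c,h,w)
    dist_of_i_from_all_j = []
    for train_file in train_files:
        train_j = torch.load(train_file)  # one chunk of train, shape (b',c,h,w)
        dist_of_i_from_all_j.append(dist(test_i, train_j))
    # distances from this test chunk to every train sample
    all_distances.append(torch.cat(dist_of_i_from_all_j, dim=1))
# and now I can take the k-smallest from all_distances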
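
Since only the k smallest distances are needed, the full distance matrix does not have to be kept. A minimal sketch of that idea, assuming the train chunks can be iterated one at a time: after each chunk, keep only the running k best per test sample. The names `update_topk` and `knn_over_chunks` are hypothetical helpers, and the random chunks stand in for the tensors loaded from the files.

```python
import torch

def update_topk(best_d, best_i, new_d, new_i, k):
    """Merge new candidates into the running k smallest distances per row."""
    d = torch.cat([best_d, new_d], dim=1)
    i = torch.cat([best_i, new_i], dim=1)
    kk = min(k, d.shape[1])  # fewer than k candidates may exist so far
    top_d, pos = torch.topk(d, kk, dim=1, largest=False)
    return top_d, torch.gather(i, 1, pos)

def knn_over_chunks(test_chunk, train_chunks, k):
    """train_chunks: iterable of (b, c, h, w) tensors, consumed one at a time."""
    q = test_chunk.flatten(1)
    best_d = torch.empty(q.shape[0], 0, dtype=q.dtype)
    best_i = torch.empty(q.shape[0], 0, dtype=torch.long)
    offset = 0  # global index of the first sample in the current train chunk
    for chunk in train_chunks:
        x = chunk.flatten(1)
        new_d = torch.cdist(q, x)
        new_i = torch.arange(offset, offset + x.shape[0]).expand(q.shape[0], -1)
        best_d, best_i = update_topk(best_d, best_i, new_d, new_i, k)
        offset += x.shape[0]
    return best_d, best_i

torch.manual_seed(0)
test_chunk = torch.rand(4, 2, 3, 3)
train_chunks = [torch.rand(6, 2, 3, 3) for _ in range(3)]  # stand-ins for loaded files
best_d, best_i = knn_over_chunks(test_chunk, train_chunks, k=5)
print(best_d.shape, best_i.shape)
```

With the files from the loop above, `train_chunks` could be a generator such as `(torch.load(f) for f in train_files)`, so that only one train chunk is resident at a time and the memory held per test chunk is O(b * k) rather than O(b * b2).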

What I thought might work

I came across the FAISS repository, in which they explain that this process can be sped up (maybe?) using their solutions, though I'm not quite sure how. Regardless, any approach would help!


  • Did you check the FAISS documentation?

    If what you need is the L2 norm (torch.cdist uses p=2 by default), then it is quite straightforward. The code below is an adaptation of the FAISS docs to your example:

    import faiss
    import numpy as np
    d = 64                           # dimension (c*h*w once each sample is flattened)
    nb = 100000                      # database size
    nq = 10000                       # nb of queries
    np.random.seed(1234)             # make reproducible
    x_test = np.random.random((nb, d)).astype('float32')   # vectors stored in the index
    x_test[:, 0] += np.arange(nb) / 1000.
    x_train = np.random.random((nq, d)).astype('float32')  # query vectors
    x_train[:, 0] += np.arange(nq) / 1000.
    index = faiss.IndexFlatL2(d)     # build the index (exact L2 search)
    index.add(x_test)                # add vectors to the index
    k = 100                          # take the 100 closest neighbors
    D, I = index.search(x_train, k)  # actual search; D holds squared L2 distances
    print(I[:5])                     # neighbors of the first 5 queries
    print(I[-5:])                    # neighbors of the last 5 queries