Edit: apparently DGL is working on it already: https://github.com/dmlc/dgl/pull/3641
I have several types of embeddings, and each one needs its own linear projection. I can solve the problem with a for loop like this:
```python
emb_out = dict()
for ntype in ntypes:
    emb_out[ntype] = self.lin_layer[ntype](emb[ntype])
```
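The loop above assumes the per-type layers live in a container indexed by type. Here is a minimal, self-contained sketch of such a module. It assumes `nn.ModuleDict`, which requires string keys, and the names `TypedProjection`, `in_size` and `out_size` are hypothetical. `emb` is assumed to be a dict of per-type tensors.

```python
import torch
from torch import nn


class TypedProjection(nn.Module):
    # Hypothetical module: one nn.Linear per embedding (node) type
    def __init__(self, ntypes, in_size, out_size):
        super().__init__()
        self.lin_layer = nn.ModuleDict(
            {ntype: nn.Linear(in_size, out_size) for ntype in ntypes}
        )

    def forward(self, emb):
        # emb: dict mapping ntype -> tensor of shape (n_items_of_that_type, in_size)
        emb_out = dict()
        for ntype, x in emb.items():
            # Python-level loop: one small matmul launch per type
            emb_out[ntype] = self.lin_layer[ntype](x)
        return emb_out


proj = TypedProjection(["user", "item"], in_size=8, out_size=4)
emb = {"user": torch.randn(3, 8), "item": torch.randn(5, 8)}
out = proj(emb)
```

Each type costs one separate kernel launch, so the overhead grows with the number of types rather than with the batch size.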
Ideally, though, I would like some sort of scatter operation that runs all the projections in parallel. Something like:
```python
pytorch_scatter(lin_layers, embeddings, layer_map, reduce='matmul')  # pseudocode for the desired behaviour
```
where `layer_map` tells which embedding should go through which layer. For example, with 2 types of linear layers and batch_size = 5, `layer_map` would be something like [1,0,1,1,0].
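To pin down the intended semantics, output row `i` should equal `lin_layers[layer_map[i]] @ embeddings[i]`. One way to express this without a Python loop is sketched below. It is an alternative built from plain PyTorch indexing and `torch.bmm`, not a `pytorch_scatter` feature: stack the weights, gather one matrix per row by indexing with `layer_map`, then do a batched matrix-vector product. The function name `gather_matmul` is hypothetical. The gather still materializes one weight copy per embedding (batch_size x output_size x inp_size), so it does not avoid the duplication the question is about.

```python
import torch


def gather_matmul(lin_layers, embeddings, layer_map):
    # lin_layers: list of n_layer tensors, each (output_size, inp_size)
    # embeddings: (batch_size, inp_size); layer_map: (batch_size,) integer tensor
    weights = torch.stack(lin_layers)   # (n_layer, output_size, inp_size)
    per_row = weights[layer_map]        # (batch_size, output_size, inp_size), one copy per row
    # Batched matrix-vector product: out[b] = per_row[b] @ embeddings[b]
    return torch.bmm(per_row, embeddings.unsqueeze(-1)).squeeze(-1)


# The example from the text: 2 layer types, batch_size = 5
lin_layers = [torch.eye(3), 2 * torch.eye(3)]  # layer 0 = identity, layer 1 doubles its input
embeddings = torch.arange(15, dtype=torch.float32).view(5, 3)
layer_map = torch.tensor([1, 0, 1, 1, 0])
out = gather_matmul(lin_layers, embeddings, layer_map)
# Rows 0, 2 and 3 are doubled; rows 1 and 4 pass through unchanged
```

The last step can equivalently be written as `torch.einsum('boi,bi->bo', per_row, embeddings)`.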
Is it possible to vectorize the for loop efficiently, in the spirit of pytorch_scatter? Please see the minimal examples below.
```python
import random

import torch

seed = 42
torch.manual_seed(seed)
random.seed(seed)


def matmul_single_embtype(lin_layers, embeddings, layer_map):
    # Baseline: run a single linear layer over all embeddings, irrespective of type
    output_embeddings = torch.matmul(lin_layers[0], embeddings.T).T
    return output_embeddings


def matmul_for_loop(lin_layers, embeddings, layer_map):
    # Let each embedding type have its own projection, looping over the types present
    output_embeddings = dict()
    for emb_type in torch.unique(layer_map).tolist():
        mask = layer_map == emb_type  # boolean mask selecting the rows of this type
        output_embeddings[emb_type] = torch.matmul(lin_layers[emb_type], embeddings[mask].T).T
    return output_embeddings


def matmul_scatter(lin_layers, embeddings, layer_map):
    # Parallelize the for loop by building a block-diagonal matrix of the linear layers.
    # This is very inefficient, because it creates a copy of the layer for each embedding
    # (plus all the zero off-diagonal blocks) instead of broadcasting.
    mapped_lin_layers = [lin_layers[i] for i in layer_map.tolist()]
    mapped_lin_layers = torch.block_diag(*mapped_lin_layers)  # (batch_size*output_size) x (batch_size*inp_size)
    embeddings_stacked = embeddings.view(-1, 1)  # stack all embeddings into one column to multiply the block matrix
    output_embeddings = torch.matmul(mapped_lin_layers, embeddings_stacked).view(embeddings.shape[0], -1)
    return output_embeddings


"""
GENERATE DATA
lin_layers:
    List of n_layer matrices, each of size output_size x inp_size
embeddings:
    Matrix of size batch_size x inp_size
layer_map:
    Vector of size batch_size stating which layer each embedding should go through
"""
emb_size = 32
batch_size = 500
emb_types = 20
# Integer tensor, so that `layer_map == i` yields a boolean mask usable for indexing
layer_map = torch.tensor([random.choice(range(emb_types)) for _ in range(batch_size)])
lin_layers = [torch.arange(emb_size * emb_size, dtype=torch.float32).view(emb_size, emb_size) for _ in range(emb_types)]
embeddings = torch.arange(batch_size * emb_size, dtype=torch.float32).view(batch_size, emb_size)
grouped_emb = {i: embeddings[layer_map == i] for i in torch.unique(layer_map).tolist()}  # separate embeddings by type

# Run experiments (IPython/Jupyter %timeit magic)
%timeit matmul_scatter(lin_layers, embeddings, layer_map)
%timeit matmul_for_loop(lin_layers, embeddings, layer_map)
%timeit matmul_single_embtype(lin_layers, embeddings, layer_map)
```
```
matmul_scatter:        133 ms ± 2.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
matmul_for_loop:       1.64 ms ± 14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
matmul_single_embtype: 31.4 µs ± 805 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
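A rough size count likely accounts for much of the gap between `matmul_scatter` and the other two versions. `torch.block_diag` materializes a dense (batch_size·output_size) x (batch_size·inp_size) matrix that is almost entirely zeros, and the matmul then touches every entry of it. Gathering one weight copy per row would need only batch_size x output_size x inp_size entries. The arithmetic below uses the sizes from the example above and assumes float32 (4 bytes per element).

```python
emb_size = 32       # inp_size == output_size in the example above
batch_size = 500
bytes_per_float32 = 4

# Dense block-diagonal matrix built by torch.block_diag in matmul_scatter
block_diag_elems = (batch_size * emb_size) ** 2
# One (output_size x inp_size) weight copy per embedding, e.g. by indexing a stacked weight tensor
gathered_elems = batch_size * emb_size * emb_size
# Share of the block-diagonal matrix occupied by the actual layer blocks (the rest is zeros)
nonzero_fraction = gathered_elems / block_diag_elems

print(block_diag_elems * bytes_per_float32)  # 1,024,000,000 bytes, about 1 GB
print(gathered_elems * bytes_per_float32)    # 2,048,000 bytes, about 2 MB
print(nonzero_fraction)                      # equals 1 / batch_size
```

Only 1/batch_size of the block-diagonal matrix holds real weights. The other entries are zeros, yet each one is still allocated and multiplied.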
Related Stack Overflow question: how to vectorize the scatter-matmul operation
Related issue in pytorch: https://github.com/pytorch/pytorch/issues/31942