Is it possible to parallelize a (natively) single-batch model?
Usually, parallelization is done via torch.bmm (batched matrix multiplication) instead of torch.matmul, with one dimension fixed specifically for the batch. However, this is not available for torch.tensordot, for example.
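For clarity, this is the kind of batched operation I mean (a minimal sketch with arbitrary shapes):

import torch

# one matrix multiplication per batch element, all in a single call
A = torch.randn(8, 3, 4)   # (batch, n, m)
B = torch.randn(8, 4, 5)   # (batch, m, p)
C = torch.bmm(A, B)        # (batch, n, p)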
So if one has a model built on an operation like torch.tensordot, is it possible to compute the gradient for each element of the batch in parallel? Ideally, the parallelization should work for both training and inference.
A code example:
import torch
import torch.nn as nn

class LinearMultidimModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LinearMultidimModel, self).__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, hidden_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # Use torch.tensordot to contract the input and hidden dimensions
        out = torch.tensordot(x, self.weight, dims=[[0, 1], [0, 1]]) + self.bias
        return out

# Example usage
input_dim = 3
hidden_dim = 2
output_dim = 1
model = LinearMultidimModel(input_dim, hidden_dim, output_dim)

# Dummy input -- but what if I want to put in a batch, torch.randn(batch_size, input_dim, hidden_dim)?
x = torch.randn(input_dim, hidden_dim)
output = model(x)
print(output)
Keep in mind that if there were no hidden_dim, the batching would work natively: one could remove hidden_dim entirely and get a result with x = torch.randn(5, input_dim).
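For reference, this is roughly the hidden_dim-free variant I mean (a minimal sketch, not my actual model), where the extra leading dimension is broadcast automatically:

import torch
import torch.nn as nn

class LinearModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # matmul broadcasts over any leading (batch) dimensions
        return x @ self.weight + self.bias

model = LinearModel(3, 1)
print(model(torch.randn(5, 3)).shape)  # torch.Size([5, 1]) -- batched without any extra work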
I've tried using einsum, but that only works for a fixed number of hidden dimensions...
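Roughly, the einsum version I tried looks like this (a sketch; it handles a batch, but the subscripts hard-code exactly one hidden dimension):

import torch

# batched forward via einsum -- works, but only for exactly one hidden dimension,
# because the subscript string is fixed
def forward_einsum(x, weight, bias):
    # x: (batch, input_dim, hidden_dim), weight: (input_dim, hidden_dim, output_dim)
    return torch.einsum('bih,iho->bo', x, weight) + bias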
You can use torch.func.vmap for exactly this purpose:
import torch
import torch.nn as nn

class LinearMultidimModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LinearMultidimModel, self).__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, hidden_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # torch.tensordot performs the (unbatched) linear transformation
        out = torch.tensordot(x, self.weight, dims=[[0, 1], [0, 1]]) + self.bias
        return out

input_dim = 3
hidden_dim = 2
output_dim = 1
model = LinearMultidimModel(input_dim, hidden_dim, output_dim)

# create input with a batch dimension
batch_size = 8
x = torch.randn(batch_size, input_dim, hidden_dim)

# unbatched reference: loop over the batch manually
y1 = torch.stack([model(i) for i in x])

# vmap the model to make it batched
model_batched = torch.func.vmap(model)

# batched inference in a single call
y2 = model_batched(x)

# both approaches give the same result
assert torch.allclose(y1, y2)
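Regarding training: vmap composes with autograd, so you should be able to backpropagate through the batched call as usual. A quick sketch, reusing the model and x from above with a made-up target:

# dummy target, just for illustration
target = torch.randn(batch_size, output_dim)

# gradients flow through the vmapped call like any other op
loss = ((model_batched(x) - target) ** 2).mean()
loss.backward()
print(model.weight.grad.shape)  # torch.Size([3, 2, 1]) == (input_dim, hidden_dim, output_dim)

If you literally need a separate gradient per batch element (per-sample gradients) rather than the usual accumulated gradient, torch.func.grad can be composed with vmap as well.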