from jax import random, vmap
from jax import numpy as jnp
from flax import linen as nn


def f(s, layers, do, dx):
    # Apply the i-th Dense layer to the i-th inner-batch element of s.
    x = jnp.zeros((do, dx))
    for i, layer in enumerate(layers):
        x = x.at[i].set(layer(s[i]))
    return x


class net(nn.Module):
    dx: int
    do: int

    def setup(self):
        # One independent Dense layer per inner-batch element.
        self.layers = [nn.Dense(self.dx, use_bias=False)
                       for _ in range(self.do)]

    def __call__(self, s):
        # Map f over the outer batch (axis 0 of s); the layers are not mapped.
        x = vmap(f, in_axes=(0, None, None, None))(s, self.layers, self.do, self.dx)
        return x


if __name__ == '__main__':
    seed = 123
    key = random.PRNGKey(seed)
    key, subkey = random.split(key)
    outer_batches = 4
    s_observations = 5  # AKA the inner batch
    x_features = 2
    s_features = 3
    s_shape = (outer_batches, s_observations, s_features)
    s = random.uniform(subkey, s_shape)
    key, subkey = random.split(key)
    model = net(x_features, s_observations)
    p = model.init(subkey, s)
    x = model.apply(p, s)
    # Stack the per-layer (s_features, x_features) kernels into one array.
    params = p['params']
    pkernels = jnp.array([params[k]['kernel'] for k in params.keys()])
    # Reference computation: x_[b, o] = s[b, o] @ pkernels[o]
    g = vmap(vmap(lambda a, b: a @ b), in_axes=(0, None))
    x_ = g(s, pkernels)
    print('s shape:', s.shape)
    print('p shape:', pkernels.shape)
    print('x shape:', x.shape)
    print('x_ shape:', x_.shape)
    print('sum of difference:', jnp.sum(x - x_))
Hi. I need some "batch-specific" parameters in my model. Here, there is an "inner batch" of length do such that there is a flax.linen.Dense instance for each element in that batch. The outer batch just passes multiple data instances into those layers. I accomplish this by creating a list of flax.linen.Dense instances in the setup method. Then, in the __call__ method, I iterate over those layers to fill up an array. This iteration is encapsulated in a function f, and that function is wrapped in jax.vmap.
I have also included some equivalent logic written as matrix multiplication (see the function g) to make it explicit what operation I was hoping to capture with this class.
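The operation that g computes can also be written as a single einsum, which makes the intended indexing explicit: for outer batch element b and inner batch element o, x[b, o] = s[b, o] @ K[o], where K stacks one (s_features, x_features) kernel per inner-batch element. A minimal sketch, assuming random stand-in arrays whose shapes mirror the question's constants (the values themselves are arbitrary):

```python
import jax
import jax.numpy as jnp

# Shapes mirroring the question: outer batch 4, inner batch 5,
# 3 input features, 2 output features.
outer_batches, s_observations, s_features, x_features = 4, 5, 3, 2

k_s, k_k = jax.random.split(jax.random.PRNGKey(0))
s = jax.random.uniform(k_s, (outer_batches, s_observations, s_features))
# One kernel per inner-batch element, stacked along axis 0.
kernels = jax.random.normal(k_k, (s_observations, s_features, x_features))

# The nested-vmap formulation from the question.
g = jax.vmap(jax.vmap(lambda a, b: a @ b), in_axes=(0, None))
x_vmap = g(s, kernels)

# The same contraction as one einsum: b = outer batch, o = inner batch,
# i = input feature, j = output feature.
x_einsum = jnp.einsum('boi,oij->boj', s, kernels)

print(x_vmap.shape)  # (4, 5, 2)
print(jnp.allclose(x_vmap, x_einsum, atol=1e-5))
```

The einsum form is only a restatement of g; it does not by itself answer how to hold the parameters inside a Flax module.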
I would like to replace the for-loop in the __call__ method with a call to jax.vmap. Of course, I get an error when I pass a list to vmap, and I also get an error when I try to put multiple Dense instances in a JAX array. Is there an alternative to using a list to contain my multiple Dense instances? A constraint is that I should be able to create an arbitrary number of Dense instances at model initialization time.
vmap maps a single function over a batch axis of its inputs. You are attempting to use it to map multiple functions (one Dense layer per inner-batch element) over batches of data, which it cannot do.
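To make the distinction concrete: vmap can batch over the arguments of one function, including its parameters, but a Python list of distinct layer objects gives it no array axis to map over. A minimal pure-JAX sketch, assuming each "layer" is just a bias-free matrix multiply (the helper name dense is hypothetical):

```python
import jax
import jax.numpy as jnp


def dense(kernel, x):
    # Hypothetical helper: a bias-free dense layer written as a plain
    # function of its parameters.
    return x @ kernel


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
kernels = jax.random.normal(k1, (5, 3, 2))  # five "layers", stacked along axis 0
xs = jax.random.normal(k2, (5, 3))          # one input vector per layer

# One function, many parameter sets: map over axis 0 of both arguments.
ys = jax.vmap(dense)(kernels, xs)

# Equivalent Python loop, shown only for comparison.
ys_loop = jnp.stack([dense(kernels[i], xs[i]) for i in range(5)])

print(ys.shape)  # (5, 2)
```

The key design point is that the per-layer differences live in the stacked kernels array, not in five separate callables.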
Updated answer based on updated question:
Since the layers are identical apart from their parameters, it sounds like what you want is to map a single Dense layer, with one set of parameters per inner-batch element, over a batch of data. It might look something like this:
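# Continues the question's __main__ block: reuses s, subkey, s_observations
# and x_features, and assumes "from flax import linen as nn" is imported.
# One PRNG key per inner-batch element, so each parameter set is initialized differently.
keys = vmap(random.fold_in, in_axes=(None, 0))(subkey, jnp.arange(s_observations))
# A single Dense definition; vmap creates one set of parameters per key.
model = nn.Dense(x_features, use_bias=False)
# Map over the keys (axis 0) and over the inner batch axis of s (axis 1).
p = vmap(model.init, in_axes=(0, 1))(keys, s)
# Apply each parameter set to its slice of s and put the mapped axis back at position 1.
x = vmap(model.apply, in_axes=(0, 1), out_axes=1)(p, s)
pkernels = p['params']['kernel']  # shape (s_observations, s_features, x_features)
g = vmap(vmap(lambda a, b: a @ b), in_axes=(0, None))
x_ = g(s, pkernels)
print('sum of difference:', jnp.sum(x - x_))
# sum of difference: 0.0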
Previous answer
In general, the fix would be to define a single parameterized layer that you can pass to vmap. In the example you gave, every layer is identical, and so to achieve the result you're looking for you could write something like this:
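from jax import vmap
from flax import linen as nn

def f(s, layer, dx):
    # dx is unused; it is kept only to mirror the question's signature.
    return layer(s)

class net(nn.Module):
    dx: int
    do: int

    def setup(self):
        # A single Dense layer whose kernel is shared by every inner-batch element.
        self.layer = nn.Dense(self.dx, use_bias=False)

    def __call__(self, s):
        x = vmap(f, in_axes=(0, None, None))(s, self.layer, self.dx)
        return x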
If you had different parameters per layer, you could still achieve this within vmap by passing those parameters to vmap as a mapped argument as well.