I have been trying to understand this official example. However, I am very confused about the use of `partial` in two places.
For example, in line 94, we have the following:
```python
conv = partial(self.conv, use_bias=False, dtype=self.dtype)
```
I am not sure why it is possible to apply `partial` to a class, or where later in the code the missing arguments are filled in (if they need to be).
Coming to the final definition, I am even more confused. For example,
```python
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],
                   block_cls=ResNetBlock)
```
Where is an argument such as `stage_sizes=[2, 2, 2, 2]` actually applied?
Thank you
`functools.partial` partially evaluates a function: it binds some of the function's arguments now and stores them for when the function is called later. Here's an example of it being used with a function:
```python
from functools import partial

def f(x, y, z):
    print(f"{x=} {y=} {z=}")

g = partial(f, 1, z=3)
g(2)
# x=1 y=2 z=3
```
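A quick way to see what `partial` actually stores is to inspect the object it returns. This sketch redefines `f` from above so it runs on its own; `func`, `args` and `keywords` are standard attributes of `functools.partial` objects, and the `{x=}` f-string syntax needs Python 3.8+.

```python
from functools import partial

def f(x, y, z):
    print(f"{x=} {y=} {z=}")

g = partial(f, 1, z=3)

# The partial object just remembers the callable and the bound arguments.
print(g.func is f)   # True
print(g.args)        # (1,)
print(g.keywords)    # {'z': 3}

# Calling g passes the stored positional args first, then the new ones,
# and merges keywords; a keyword given at call time overrides the stored one.
g(2, z=30)
# x=1 y=2 z=30
```

Nothing is evaluated when `partial` is created; the bound arguments only reach `f` at the moment `g(...)` is called.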
And here is an example of it being used on a class constructor:
```python
from functools import partial
from typing import NamedTuple

class MyClass(NamedTuple):
    a: int
    b: int
    c: int

make_class = partial(MyClass, 1, c=3)
print(make_class(b=2))
# MyClass(a=1, b=2, c=3)
```
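The use in the Flax example is conceptually the same: `partial(f, ...)` returns a callable that, when called, applies the bound arguments to the original callable, whether that is a function, a method, or a class constructor. A class is itself callable (calling it constructs an instance), which is why `partial` can be applied to it just like to a function.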
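The `conv = partial(self.conv, use_bias=False, dtype=self.dtype)` line from the question follows this pattern: `self.conv` holds a class, and the remaining arguments (number of features, kernel size, ...) are supplied each time the partial is called. Below is a minimal pure-Python sketch of that idea; `Conv` and `Model` are hypothetical stand-ins, not the real Flax classes, so it runs without Flax installed.

```python
from dataclasses import dataclass
from functools import partial

@dataclass
class Conv:
    """Hypothetical stand-in for a layer class such as a convolution."""
    features: int
    kernel_size: tuple
    use_bias: bool = True
    dtype: str = "float32"

class Model:
    conv = Conv          # a class stored as an attribute, like self.conv
    dtype = "float16"

    def setup_layers(self):
        # Bind the arguments shared by every convolution once ...
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        # ... then fill in the per-layer arguments at each call site.
        return [conv(64, (7, 7)), conv(128, (3, 3))]

layers = Model().setup_layers()
print(layers[0])
# Conv(features=64, kernel_size=(7, 7), use_bias=False, dtype='float16')
```

Each `conv(...)` call is where the "missing" arguments get filled in; the partial can also be passed on to other code (for example a block class) that supplies them later.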
For example, the `ResNet18` callable created here:
```python
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],
                   block_cls=ResNetBlock)
```
is a partially evaluated `ResNet` constructor, and it is called in a test here:
```python
@parameterized.product(
    model=(models.ResNet18, models.ResNet18Local)
)
def test_resnet_18_v1_model(self, model):
    """Tests ResNet18 V1 model definition and output (variables)."""
    rng = jax.random.PRNGKey(0)
    model_def = model(num_classes=2, dtype=jnp.float32)
    variables = model_def.init(
        rng, jnp.ones((1, 64, 64, 3), jnp.float32))
    self.assertLen(variables, 2)
    self.assertLen(variables['params'], 11)
```
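Here `model` is the partially evaluated `ResNet18`, and the call `model(num_classes=2, dtype=jnp.float32)` is where `stage_sizes` and `block_cls` are finally applied: `partial` merges its stored keywords with the new ones, so the call is equivalent to `ResNet(stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock, num_classes=2, dtype=jnp.float32)` and returns the fully instantiated `ResNet` object.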
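To see this without Flax, here is a minimal pure-Python sketch that mirrors the `ResNet18` definition. The `ResNet` and `ResNetBlock` dataclasses below are hypothetical, heavily simplified stand-ins (assuming only that the real module accepts these fields as constructor keywords), not the actual Flax classes.

```python
from dataclasses import dataclass
from functools import partial
from typing import Sequence

@dataclass
class ResNetBlock:
    """Hypothetical placeholder block; only records its filter count."""
    filters: int

@dataclass
class ResNet:
    """Simplified stand-in for the Flax ResNet module (not the real class)."""
    stage_sizes: Sequence[int]
    block_cls: type
    num_classes: int
    dtype: str = "float32"

    def blocks(self):
        # stage_sizes decides how many blocks each stage gets.
        return [self.block_cls(64 * 2**i)
                for i, n in enumerate(self.stage_sizes) for _ in range(n)]

# Nothing is constructed yet: ResNet18 only remembers ResNet and two keywords.
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)

# This call is where stage_sizes is applied; it is equivalent to
# ResNet(stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock,
#        num_classes=2, dtype="float32")
model_def = ResNet18(num_classes=2, dtype="float32")
print(model_def.stage_sizes)           # [2, 2, 2, 2]
print(len(model_def.blocks()))         # 8
print(isinstance(model_def, ResNet))   # True
```

Defining variants this way keeps a single `ResNet` class while letting each named model (`ResNet18`, and similar) fix only the arguments that distinguish it.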