jax flax

Trouble understanding the use of partial in the official Flax ResNet example


I have been trying to understand this official example. However, I am very confused about the use of partial in two places.

For example, in line 94, we have the following:

conv = partial(self.conv, use_bias=False, dtype=self.dtype)

I am not sure why it is possible to apply partial to a class, or where later in the code the missing arguments are filled in (if they need to be).

Coming to the final definition, I am even more confused. For example,

ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],
               block_cls=ResNetBlock)

Where do we apply arguments such as stage_sizes=[2, 2, 2, 2]?

Thank you


Solution

  • functools.partial binds some of a callable's arguments ahead of time and returns a new callable; the remaining arguments are supplied when that new callable is called later. Here's an example of it being used with a function:

    from functools import partial
    
    def f(x, y, z):
      print(f"{x=} {y=} {z=}")
    
    g = partial(f, 1, z=3)
    g(2)
    # x=1 y=2 z=3
    

    And here is an example of it being used on a class constructor:

    from typing import NamedTuple
    
    class MyClass(NamedTuple):
      a: int
      b: int
      c: int
    
    make_class = partial(MyClass, 1, c=3)
    print(make_class(b=2))
    # MyClass(a=1, b=2, c=3)
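
It may help to look at what partial actually stores. The sketch below reuses the MyClass example and inspects the func, args, and keywords attributes that functools.partial objects expose. It also shows that a keyword bound in the partial can still be overridden at call time.

```python
from functools import partial
from typing import NamedTuple


class MyClass(NamedTuple):
    a: int
    b: int
    c: int


make_class = partial(MyClass, 1, c=3)

# The partial object remembers the wrapped callable and the bound arguments.
print(make_class.func is MyClass)  # True
print(make_class.args)             # (1,)
print(make_class.keywords)         # {'c': 3}

# Calling it forwards the bound arguments plus the ones supplied now.
print(make_class(b=2))             # MyClass(a=1, b=2, c=3)

# A keyword bound in the partial can be overridden by the caller.
print(make_class(b=2, c=30))       # MyClass(a=1, b=2, c=30)
```

Nothing is constructed when partial(...) itself runs; MyClass is only instantiated when make_class is called.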
    

    The use in the Flax example is conceptually the same: partial(f, ...) returns a callable that, when called, passes the bound arguments along with any newly supplied ones to the original callable, whether that is a function, a method, or a class constructor.

    For example, the ResNet18 function created here:

    ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],
                       block_cls=ResNetBlock)
    

    is a ResNet constructor with stage_sizes and block_cls already bound, and it is called in a test here:

      @parameterized.product(
          model=(models.ResNet18, models.ResNet18Local)
      )
      def test_resnet_18_v1_model(self, model):
        """Tests ResNet18 V1 model definition and output (variables)."""
        rng = jax.random.PRNGKey(0)
        model_def = model(num_classes=2, dtype=jnp.float32)
        variables = model_def.init(
            rng, jnp.ones((1, 64, 64, 3), jnp.float32))
    
        self.assertLen(variables, 2)
        self.assertLen(variables['params'], 11)
    

    model here is the partial object ResNet18 (or ResNet18Local). Calling it with num_classes and dtype returns a fully instantiated ResNet object that combines those arguments with the stage_sizes and block_cls bound in the partial definition.
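
The same idea covers the conv = partial(self.conv, ...) line from the question. Here is a minimal sketch in plain Python, with no Flax required. ToyConv, ToyResNet, ToyResNet18, and first_layer are hypothetical stand-ins invented for illustration, not the classes from the Flax example. The sketch assumes that self.conv holds a layer class, and that the remaining arguments (here the number of features and the kernel size) are supplied wherever conv(...) is called later in the same method.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Sequence


@dataclass
class ToyConv:
    """Hypothetical stand-in for a convolution layer class."""
    features: int
    kernel_size: tuple
    use_bias: bool = True
    dtype: Any = float


@dataclass
class ToyResNet:
    """Hypothetical stand-in for the ResNet module."""
    stage_sizes: Sequence[int]
    num_classes: int
    dtype: Any = float
    conv: Callable = ToyConv  # a class stored in a field, like self.conv

    def first_layer(self):
        # Bind the options shared by every convolution once...
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        # ...and fill in the remaining arguments where a layer is built.
        return conv(64, (7, 7))


# stage_sizes is bound here; num_classes and dtype are supplied by the caller.
ToyResNet18 = partial(ToyResNet, stage_sizes=[2, 2, 2, 2])

model = ToyResNet18(num_classes=2, dtype=int)
layer = model.first_layer()

print(model.stage_sizes)  # [2, 2, 2, 2]
print(layer)
# ToyConv(features=64, kernel_size=(7, 7), use_bias=False, dtype=<class 'int'>)
```

In this sketch the bound stage_sizes value is applied at the moment ToyResNet18(...) is called, and the bound use_bias and dtype values are applied at the moment conv(...) is called.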