python jax

Why are some nested Python functions defined as `def _():`?


I understand that internal functions are prefixed with `_` to mark them as helper/internal functions, which also helps with tooling. But I also find some functions named just `_`, and I can't even find where they are called from. For example, from:

https://github.com/jax-ml/jax/blob/7412adec21c534f8e4bcc627552f28d162decc86/jax/_src/pallas/mosaic/helpers.py#L72

def run_on_first_core(core_axis_name: str):
  """Runs a function on the first core in a given axis."""
  num_cores = jax.lax.axis_size(core_axis_name)
  if num_cores == 1:
    return lambda f: f()

  def wrapped(f):
    core_id = jax.lax.axis_index(core_axis_name)

    @pl_helpers.when(core_id == 0)
    @functools.wraps(f)
    def _(): ## How is this called?
      return f()

  return wrapped

There are several of these in an internal code base, but the snippet above is a public reference.


Solution

  • A name that is just `_` is different from a name prefixed with `_`. By convention, a bare `_` means "I need to supply a name to satisfy the syntax, but I don't actually need to use it"*. That is the case here, since `_` is never referenced anywhere.

    As for how this function actually gets called: the `when` decorator appears to be defined here:

    def when(condition):
      def _wrapped(f):
        if isinstance(condition, bool):
          if condition:
            f()
        else:
          jax.lax.cond(condition, f, lambda: None)
      return _wrapped
    

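    The decorator receives the function as `f` and calls it itself if `condition` holds (or, when `condition` is a traced JAX value rather than a Python `bool`, passes it to `jax.lax.cond` as the true branch). Decorators are applied as soon as the `def` statement executes, so that is the call: there is no separate call site to find. And because `_wrapped` returns nothing, the name `_` is bound to `None` after decoration, which is one more reason its name doesn't matter.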


    * I could have sworn this convention comes from PEP 8, but I've skimmed the document twice now and can't find where it says so.