Tags: python, python-3.x, fb-hydra, omegaconf

Use data types or other library-specific variables as arguments in Hydra


I would like to use Python data types (both built-in ones and ones imported from libraries such as NumPy, TensorFlow, etc.) as arguments in my Hydra configuration. Something like:

# config.yaml

arg1: np.float32
arg2: tf.float16

I'm currently doing this instead:

# config.yaml

arg1: 'float32'
arg2: 'float16'

# my python code
import numpy as np
import tensorflow as tf
# ...
DTYPES_LOOKUP = {
  'float32': np.float32,
  'float16': tf.float16
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]

Is there a more hydronic/elegant solution?


Solution

  • Does the hydra.utils.get_class function solve this problem for you?

    # config.yaml
    
    arg1: numpy.float32  # note: use "numpy" here, not "np"
    arg2: tensorflow.float16
    
    # python code
    ...
    from hydra.utils import get_class
    arg1 = get_class(config.arg1)
    arg2 = get_class(config.arg2)
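    Roughly speaking, get_class takes a dotted path, imports the module part and looks up the last component as an attribute. That is why the full module name ("numpy", not the alias "np") is needed. The sketch below illustrates the idea using only the standard library; `locate` is a hypothetical helper, not Hydra's actual implementation, and `collections.OrderedDict` / `os.path.join` stand in for the NumPy and TensorFlow objects.

```python
import importlib
from typing import Any


def locate(path: str) -> Any:
    """Hypothetical helper: turn "package.module.attr" into the object it names."""
    # Split e.g. "numpy.float32" into the module "numpy" and the attribute "float32".
    module_name, _, attr_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Standard-library stand-ins for "numpy.float32" / "tensorflow.float16":
ordered_dict_cls = locate("collections.OrderedDict")
join = locate("os.path.join")
```

    Unlike a hand-written lookup table, nothing has to be listed in advance: any importable name can appear in the config.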
    

    Update 1: using a custom resolver

    Based on miccio's comment below, here is a demonstration using an OmegaConf custom resolver to wrap the get_class function.

    from omegaconf import OmegaConf
    from hydra.utils import get_class
    
    OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))
    
    config = OmegaConf.create("""
    # config.yaml
    
    arg1: "${get_cls: numpy.float32}"
    arg2: "${get_cls: tensorflow.float16}"
    """)
    
    arg1 = config.arg1
    arg2 = config.arg2
    

    Update 2:

    It turns out that get_class("numpy.float32") succeeds, but get_class("tensorflow.float16") raises a ValueError. The reason is that get_class checks that the returned value is indeed a class (using isinstance(cls, type)).

    The function hydra.utils.get_method is slightly more permissive: it only checks that the returned value is callable. However, tf.float16 is not callable either, so this does not work.

    >>> isinstance(tf.float16, type)
    False
    >>> callable(tf.float16)
    False
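    The difference between the two checks can be reproduced with standard-library objects. In this sketch the stand-ins are assumptions chosen for illustration: `collections.OrderedDict` plays the role of numpy.float32 (a class), `math.sqrt` is callable but not a class, and `math.pi` is neither, like tf.float16.

```python
import collections
import math

# Standard-library stand-ins (chosen for illustration only):
#   collections.OrderedDict -> a class, like numpy.float32
#   math.sqrt               -> callable but not a class
#   math.pi                 -> neither, like tf.float16 (a DType instance)
candidates = {
    "collections.OrderedDict": collections.OrderedDict,
    "math.sqrt": math.sqrt,
    "math.pi": math.pi,
}

# (class check used by get_class, callable check used by get_method)
checks = {name: (isinstance(obj, type), callable(obj)) for name, obj in candidates.items()}

for name, (is_class, is_callable) in checks.items():
    print(f"{name}: class={is_class} callable={is_callable}")
```

    Only the first object passes the class check, and only the first two pass the callable check, which mirrors why tf.float16 is rejected by both functions.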
    

    A custom resolver wrapping the tensorflow.as_dtype function might be in order.

    >>> tf.as_dtype("float16")
    tf.float16
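    One way to do that is sketched below, assuming only paths whose module part is exactly "tensorflow" need special handling. `resolve_dtype` is a hypothetical name: it hands TensorFlow names to tf.as_dtype (importing TensorFlow only when such a path is seen) and, for everything else, imports the module and fetches the attribute without any class or callable check.

```python
import importlib
from typing import Any


def resolve_dtype(path: str) -> Any:
    """Hypothetical resolver body: turn a dotted path into a dtype-like object."""
    module_name, _, attr_name = path.rpartition(".")
    if module_name == "tensorflow":
        # tf.float16 is a tf.DType instance, not a class, so let TensorFlow
        # build it from its name instead of applying a class check.
        tf = importlib.import_module("tensorflow")
        return tf.as_dtype(attr_name)
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    # Any other path: import the module and return the attribute as-is,
    # with no isinstance(obj, type) or callable() check.
    return getattr(importlib.import_module(module_name), attr_name)
```

    It could then be registered with `OmegaConf.register_new_resolver("dtype", resolve_dtype)` and referenced in the config as `"${dtype: tensorflow.float16}"` or `"${dtype: numpy.float32}"`. Skipping the class check is the deliberate trade-off here: the resolver trusts the config to name a dtype, so a typo surfaces as an ImportError, an AttributeError or an error from tf.as_dtype rather than the ValueError from get_class.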