I would like to use Python datatypes (both built-in ones and ones imported from libraries such as numpy, tensorflow, etc.) as arguments in my Hydra configuration. Something like:
# config.yaml
arg1: np.float32
arg2: tf.float16
I'm currently doing this instead:
# config.yaml
arg1: 'float32'
arg2: 'float16'
# my python code
import numpy as np
import tensorflow as tf
# ...
DTYPES_LOOKUP = {
    'float32': np.float32,
    'float16': tf.float16,
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]
Is there a more hydronic/elegant solution?
Does the hydra.utils.get_class function solve this problem for you?
# config.yaml
arg1: numpy.float32 # note: use "numpy" here, not "np"
arg2: tensorflow.float16
# python code
...
from hydra.utils import get_class
arg1 = get_class(config.arg1)
arg2 = get_class(config.arg2)
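For intuition, resolving a string like "numpy.float32" boils down to importing the module part of the path and fetching the attribute from it. The sketch below uses a hypothetical locate helper built on importlib; it is a simplified stand-in, not Hydra's implementation (which is more thorough), and it is demonstrated on stdlib paths so it runs without numpy or tensorflow installed.

```python
import importlib
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted path such as "collections.OrderedDict" to the object it names.

    Hypothetical, simplified helper: it only handles
    "<importable module>.<attribute>" paths.
    """
    module_name, _, attr_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted path like 'module.attr', got {path!r}")
    module = importlib.import_module(module_name)  # e.g. imports "collections"
    return getattr(module, attr_name)  # e.g. fetches OrderedDict from that module
```

With numpy installed, locate("numpy.float32") would return the same object as np.float32; get_class then adds a check on what comes back.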
Based on miccio's comment below, here is a demonstration using an OmegaConf custom resolver to wrap the get_class function.
from omegaconf import OmegaConf
from hydra.utils import get_class
OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))
config = OmegaConf.create("""
# config.yaml
arg1: "${get_cls: numpy.float32}"
arg2: "${get_cls: tensorflow.float16}"
""")
arg1 = config.arg1
arg2 = config.arg2
It turns out that get_class("numpy.float32") succeeds but get_class("tensorflow.float16") raises a ValueError.
The reason is that get_class checks that the returned value is indeed a class (using isinstance(cls, type)).
The function hydra.utils.get_method is slightly more permissive, checking only that the returned value is callable, but this still does not work with tf.float16.
>>> isinstance(tf.float16, type)
False
>>> callable(tf.float16)
False
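To make the difference between the two checks concrete without needing tensorflow, here is a stdlib-only sketch. require_class and require_callable are hypothetical helpers that mimic the kind of check described above for get_class and get_method (they are not Hydra's code); math.pi stands in for tf.float16 as an object that is neither a class nor callable.

```python
import decimal
import math
from typing import Any, Callable


def require_class(obj: Any) -> type:
    # Strict check, like the one described for get_class: obj must be a class.
    if not isinstance(obj, type):
        raise ValueError(f"{obj!r} is not a class")
    return obj


def require_callable(obj: Any) -> Any:
    # Looser check, like the one described for get_method: any callable passes.
    if not callable(obj):
        raise ValueError(f"{obj!r} is not callable")
    return obj


def passes(check: Callable[[Any], Any], obj: Any) -> bool:
    # Report whether a check accepts obj instead of letting ValueError propagate.
    try:
        check(obj)
        return True
    except ValueError:
        return False


# decimal.Decimal is a class, math.sqrt is callable but not a class,
# and math.pi (like tf.float16) is neither, so it fails both checks.
for name, obj in [("decimal.Decimal", decimal.Decimal),
                  ("math.sqrt", math.sqrt),
                  ("math.pi", math.pi)]:
    print(name, passes(require_class, obj), passes(require_callable, obj))
```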
A custom resolver wrapping the tensorflow.as_dtype function might be in order.
>>> tf.as_dtype("float16")
tf.float16
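One way to write such a resolver is sketched below, assuming that TensorFlow dtype names should go through tf.as_dtype while any other path (e.g. numpy.float32) is resolved as an ordinary module attribute. resolve_dtype is a hypothetical name, and TensorFlow is imported only when a tensorflow.* path is actually requested.

```python
import importlib
from typing import Any


def resolve_dtype(path: str) -> Any:
    """Turn "numpy.float32" / "tensorflow.float16" style strings into dtype objects.

    Hypothetical helper: "tensorflow.<name>" is routed through tf.as_dtype,
    every other dotted path is resolved as a plain module attribute.
    """
    module_name, _, attr_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted path like 'numpy.float32', got {path!r}")
    if module_name == "tensorflow":
        import tensorflow as tf  # lazy import: TF is only needed for TF dtypes
        return tf.as_dtype(attr_name)  # e.g. "float16" -> tf.float16
    return getattr(importlib.import_module(module_name), attr_name)
```

It could then be registered the same way as get_cls above, e.g. OmegaConf.register_new_resolver(name="dtype", resolver=resolve_dtype), and referenced in the config as arg2: "${dtype: tensorflow.float16}".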