I'm pretty new to Hydra and was wondering whether the following is possible: I have the parameter num_atom_feats in the model section, which I would like to make dependent on the feat_type parameter in the data section. In particular, if I have feat_type: type1, then I would like to have num_atom_feats: 22. If instead I initialize data with feat_type: type2, then I would like to have num_atom_feats: 200.
model:
  _target_: model.EmbNet_Lightning
  model_name: 'EmbNet'
  num_atom_feats: 22
  dim_target: 128
  loss: 'log_ratio'
  lr: 1e-3
  wd: 5e-6
data:
  _target_: data.DataModule
  feat_type: 'type1'
  batch_size: 64
  data_path: '.'
wandb:
  _target_: pytorch_lightning.loggers.WandbLogger
  name: embnet_logger
  project: ''
trainer:
  max_epochs: 1000
You can achieve this using OmegaConf's custom resolver feature.
Here's an example showing how to register a custom resolver that computes model.num_atom_feats based on the value of data.feat_type:
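from omegaconf import OmegaConf

yaml_data = """
model:
  _target_: model.EmbNet_Lightning
  model_name: 'EmbNet'
  num_atom_feats: ${compute_num_atom_feats:${data.feat_type}}
data:
  _target_: data.DataModule
  feat_type: 'type1'
"""

def compute_num_atom_feats(feat_type: str) -> int:
    # Map each supported feature type to its number of atom features.
    if feat_type == "type1":
        return 22
    if feat_type == "type2":
        return 200
    # Fail loudly on an unsupported value (a bare assert is skipped under python -O).
    raise ValueError(f"Unknown feat_type: {feat_type!r}")

# Make ${compute_num_atom_feats:...} usable inside configs.
OmegaConf.register_new_resolver("compute_num_atom_feats", compute_num_atom_feats)

cfg = OmegaConf.create(yaml_data)
assert cfg.data.feat_type == 'type1'
assert cfg.model.num_atom_feats == 22

# The interpolation is re-evaluated on access, so it follows the new value.
cfg.data.feat_type = 'type2'
assert cfg.model.num_atom_feats == 200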
I'd recommend reading through the OmegaConf docs; OmegaConf is the configuration backend used by Hydra.
The compute_num_atom_feats function is invoked lazily, when you access cfg.model.num_atom_feats in your Python code.
When using custom resolvers with Hydra, you can call OmegaConf.register_new_resolver either before you invoke your @hydra.main-decorated function, or from within the @hydra.main-decorated function itself. The important thing is that you call OmegaConf.register_new_resolver before you access cfg.model.num_atom_feats.