From the CLI, I am trying to override a config group. The structure of my conf directory is:
conf
├── config.yaml
├── optimizer
│   ├── adamw.yaml
│   ├── adam.yaml
│   ├── default.yaml
│   └── sgd.yaml
├── task
│   ├── default.yaml
│   └── nlp
│       ├── default_seq2seq.yaml
│       ├── summarization.yaml
│       └── text_classification.yaml
task/default.yaml:
# @package task
defaults:
  - _self_
  - /optimizer/adam@cfg.optimizer
_target_: src.core.task.Task
_recursive_: false
cfg:
  prefix_sep: ${training.prefix_sep}
optimizer/default.yaml:
_target_: null
lr: ${training.lr}
weight_decay: 0.001
no_decay:
  - bias
  - LayerNorm.weight
optimizer/adam.yaml:
defaults:
  - default
_target_: torch.optim.Adam
This results in the output config:
task:
  _target_: src.task.nlp.nli_generation.task.NLIGenerationTask
  _recursive_: false
  cfg:
    prefix_sep: ${training.prefix_sep}
    optimizer:
      _target_: torch.optim.Adam
      lr: ${training.lr}
      weight_decay: 0.001
      no_decay:
        - bias
        - LayerNorm.weight
I would like to be able to change the optimizer from the CLI (say, to use sgd), but I am not sure how to achieve this. I tried the following, and I understand why it fails:
python train.py task.cfg.optimizer=sgd  # fails
python train.py task.cfg.optimizer=/optimizer/sgd  # fails
GitHub discussion here.
You can't override Defaults List entries written in this form. See this. In particular:
CONFIG : A config to use when creating the output config. e.g. db/mysql, db/mysql@backup.
GROUP_DEFAULT : An overridable config. e.g. db: mysql, db@backup: mysql.
To be able to override a Defaults List entry, you need to define it as a GROUP_DEFAULT. In your case, it might look like this:
defaults:
  - _self_
  - /optimizer@cfg.optimizer: adam
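With that entry in place, the CLI override should name the entry by its absolute group (optimizer) and its final package. Assuming Hydra 1.1 or later and that task/default.yaml keeps its "# @package task" header, the final package would be task.cfg.optimizer. The sketch below is based on that assumption, so treat the exact override key as something to verify. The second command uses Hydra's --cfg job flag to print the composed config instead of running the job, which is a quick way to confirm the optimizer was actually swapped.

```shell
# Assumes task/default.yaml declares the GROUP_DEFAULT entry
#   - /optimizer@cfg.optimizer: adam
# under its "# @package task" header, so the entry ends up at task.cfg.optimizer.
# Override key (assumed) = absolute group path + "@" + final package.
python train.py optimizer@task.cfg.optimizer=sgd

# Print the composed job config without running it, to check which optimizer was selected.
python train.py optimizer@task.cfg.optimizer=sgd --cfg job
```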