I am trying out adapters on LIMU-BERT, a lightweight BERT model for IMU data. I pretrained LIMU-BERT on Dataset A and plan to add adapters and tune them on Dataset B. Here is my adapter-adding code:
```python
import adapters
import torch.nn as nn

# LIMUBertModel4Pretrain is defined in LIMU-BERT's models.py (the same file as this class).


class AdapterBERTClassifier(nn.Module):
    def __init__(self, bert_cfg, classifier=None):
        super().__init__()
        self.limu_bert = LIMUBertModel4Pretrain(bert_cfg, output_embed=True)
        self.classifier = classifier
        # Add adapter
        adapter_config = adapters.AdapterConfig(
            mh_adapter=True,
            output_adapter=True,
            reduction_factor=16,
            non_linearity="relu"
        )
        self.limu_bert.add_adapter("classification_adapter", config=adapter_config)
        self.limu_bert.train_adapter("classification_adapter")
```
However, I encountered an error:
```
Traceback (most recent call last):
  File "D:\Documents\Code\LIMU-BERT\classifier_adapter.py", line 71, in <module>
    label_test, label_estimate_test = bert_classify(args, args.label_index, train_rate, label_rate, balance=balance)
  File "D:\Documents\Code\LIMU-BERT\classifier_adapter.py", line 37, in bert_classify
    model = AdapterBERTClassifier(model_bert_cfg, classifier=classifier)
  File "D:\Documents\Code\LIMU-BERT\models.py", line 332, in __init__
    adapter_config = adapters.AdapterConfig(
TypeError: AdapterConfig.__init__() got an unexpected keyword argument 'mh_adapter'
```
Since the Adapter Configuration documentation mentions a parameter named `mh_adapter` for `adapters.AdapterConfig`, can anyone tell me what the problem is and how to solve it? Thank you for your help!
By the way, here is my `adapters` package info:

```
# Name                    Version                   Build  Channel
adapters                  1.0.1                    pypi_0    pypi
```
The `adapters.AdapterConfig` class you are using is actually the base class for all adaptation methods. According to the documentation, "This class does not define specific configuration keys, but only provides some common helper methods." That explains the error: the base class does not accept `mh_adapter` (or any other method-specific key) as a keyword argument.
You don't want to use this base class directly; instead, use the configuration class for the specific adaptation method you need. The documentation describes several configurations that include `mh_adapter`, such as `adapters.BnConfig` (bottleneck adapters). You were probably looking at the parameters defined for one of these subclasses and mistook them for parameters of the base class.
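The failure is ordinary Python keyword handling rather than anything adapter-specific. The sketch below reproduces the pattern with two hypothetical classes, `BaseConfig` and `BottleneckConfig` (illustrative stand-ins, not the library's actual source): a base class whose `__init__` accepts no configuration keys, and a dataclass subclass that declares them as fields. Passing the same keywords to each shows why only the subclass works.

```python
from dataclasses import dataclass


class BaseConfig:
    """Hypothetical stand-in for a base config class: it defines no configuration keys."""

    def __init__(self):
        pass


@dataclass
class BottleneckConfig(BaseConfig):
    """Hypothetical stand-in for a method-specific config (e.g. a bottleneck adapter).

    @dataclass generates an __init__ that accepts exactly these fields,
    overriding the argument-less BaseConfig.__init__.
    """

    mh_adapter: bool
    output_adapter: bool
    reduction_factor: float
    non_linearity: str


def try_build(cls, **kwargs):
    """Instantiate cls with kwargs; return (instance, None) or (None, error message)."""
    try:
        return cls(**kwargs), None
    except TypeError as exc:
        return None, str(exc)


params = dict(mh_adapter=True, output_adapter=True, reduction_factor=16, non_linearity="relu")

base_cfg, base_err = try_build(BaseConfig, **params)
bn_cfg, bn_err = try_build(BottleneckConfig, **params)

print("base class:", base_err)
print("subclass:", bn_cfg)
```

The base class rejects `mh_adapter` with the same kind of `TypeError` as in the question, while the subclass stores all four values.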
This is how your code might look after the modification:
```python
adapter_config = adapters.BnConfig(
    mh_adapter=True,
    output_adapter=True,
    reduction_factor=16,
    non_linearity="relu"
)
```
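For intuition about what these options configure, here is a minimal pure-Python sketch of a bottleneck adapter: the hidden vector is down-projected to roughly `hidden_size / reduction_factor` dimensions, passed through the non-linearity (ReLU here), up-projected back, and added to the input as a residual. Per the adapters documentation, `mh_adapter=True` inserts such a block after the multi-head attention sublayer of each layer and `output_adapter=True` inserts one after the feed-forward output. The helper names (`make_adapter`, `adapter_forward`, `num_params`) and the zero-initialized up-projection are assumptions of this sketch, not the library's implementation, which also handles layer norms and its own initialization. Separately, `add_adapter` and `train_adapter` are provided for the Transformers models the library supports (e.g. after `adapters.init(model)`); if `LIMUBertModel4Pretrain` is a plain `nn.Module`, those methods may not exist on it, and a hand-written module along these lines (using `torch.nn.Linear` layers) is one possible fallback.

```python
import random


def make_adapter(hidden_size, reduction_factor, seed=0):
    """Hypothetical helper: random down-projection, zero up-projection (a choice made for this sketch)."""
    bottleneck = max(1, int(hidden_size // reduction_factor))
    rng = random.Random(seed)
    # Down-projection: bottleneck x hidden_size weight matrix plus bias.
    w_down = [[rng.gauss(0.0, 0.02) for _ in range(hidden_size)] for _ in range(bottleneck)]
    b_down = [0.0] * bottleneck
    # Up-projection starts at zero here, so the adapter initially acts as an identity map.
    w_up = [[0.0] * bottleneck for _ in range(hidden_size)]
    b_up = [0.0] * hidden_size
    return {"w_down": w_down, "b_down": b_down, "w_up": w_up, "b_up": b_up}


def relu(v):
    return [max(0.0, x) for x in v]


def matvec(m, v, b):
    """Compute m @ v + b for a matrix stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(m, b)]


def adapter_forward(h, p):
    """Bottleneck adapter with residual connection: h + up(relu(down(h)))."""
    z = relu(matvec(p["w_down"], h, p["b_down"]))
    delta = matvec(p["w_up"], z, p["b_up"])
    return [x + d for x, d in zip(h, delta)]


def num_params(hidden_size, reduction_factor):
    """Trainable parameters in one adapter block: two weight matrices plus their biases."""
    m = max(1, int(hidden_size // reduction_factor))
    return hidden_size * m + m + m * hidden_size + hidden_size


hidden = 64  # illustrative size only, not LIMU-BERT's actual hidden dimension
adapter = make_adapter(hidden, reduction_factor=16)
h = [0.1 * i for i in range(hidden)]
print(len(adapter["w_down"]), num_params(hidden, 16))
```

With `hidden_size=64` and `reduction_factor=16`, the bottleneck has 4 units and each adapter block adds 64·4 + 4 + 4·64 + 64 = 580 trainable parameters, which is why tuning only adapters is cheap compared with fine-tuning the whole encoder.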