I am using the following custom feature extractor for my Stable-Baselines3 (SB3) model:
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class Encoder(nn.Module):
    """Two-layer MLP that maps observations to a feature vector.

    The input goes through ``encoder`` (Linear + ReLU, ``input_dim`` ->
    ``embedding_dim``) and then ``regressor`` (Linear + ReLU,
    ``embedding_dim`` -> ``hidden_dim``). The output width is therefore
    ``hidden_dim``, and every output value is non-negative because of the
    final ReLU.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        embedding_dim: Width of the first hidden layer.
        hidden_dim: Width of the produced features.
        output_dim: Currently unused; kept only so existing call sites keep
            working. NOTE(review): no layer produces ``output_dim`` outputs,
            presumably a leftover regression head. Confirm before relying on it.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim=2):
        super().__init__()
        # The attribute names and layer order define the state_dict keys
        # ("encoder.0.*", "regressor.0.*") that the pre-trained checkpoint is
        # loaded with, so they must stay unchanged.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.ReLU(),
        )
        self.regressor = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the ``hidden_dim``-wide features for input ``x``."""
        return self.regressor(self.encoder(x))
# Rebuild the encoder with the dimensions it was trained with and restore its
# pre-trained weights.
# NOTE(review): input_dim / embedding_dim / hidden_dim are not defined in this
# snippet; presumably they are set earlier in the script. Confirm.
model = Encoder(input_dim, embedding_dim, hidden_dim)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine. load_state_dict then copies the values into the module's own
# parameters, so the resulting model is the same either way.
# NOTE(review): torch.load unpickles the file. Only load checkpoints you trust,
# or pass weights_only=True on torch versions that support it.
model.load_state_dict(torch.load('trained_model.pth', map_location="cpu"))

# True reproduces the frozen-feature-extractor setup. Set to False to fine-tune
# the encoder together with the PPO policy, starting from the pre-trained
# weights.
FREEZE_ENCODER = True
if FREEZE_ENCODER:
    # Freeze all layers: no gradients are computed for these parameters, so
    # they are not updated during training.
    model.requires_grad_(False)
class CustomFeatureExtractor(BaseFeaturesExtractor):
    """SB3 features extractor that delegates to a pre-trained ``Encoder``.

    Args:
        observation_space: Observation space passed in by SB3.
        features_dim: Width of the features returned by ``forward``. It must
            match the wrapped encoder's output width, because SB3 sizes the
            policy/value networks from it.
        encoder: Module to wrap. If ``None``, the module-level ``model`` is
            used, as it is bound at construction time. Because this script
            later rebinds ``model`` to the PPO agent, any construction after
            that (presumably including reloading a saved agent) would wrap the
            wrong object. Pass ``encoder=`` via ``features_extractor_kwargs``
            to avoid that.
        trainable: ``None`` leaves the encoder's ``requires_grad`` flags
            untouched. ``True``/``False`` unfreezes/freezes all of its
            parameters, so the pre-trained weights can be fine-tuned by PPO.
    """

    def __init__(self, observation_space, features_dim, *, encoder=None, trainable=None):
        # The base class stores features_dim as self._features_dim, so it is
        # not re-assigned here.
        super().__init__(observation_space, features_dim)
        # Use the pre-trained model as the feature extractor. Registering it as
        # a submodule makes its parameters part of the policy's parameters.
        self.model = model if encoder is None else encoder
        if trainable is not None:
            self.model.requires_grad_(trainable)

    def forward(self, observations):
        """Encode a batch of (SB3-preprocessed) observations into features."""
        return self.model(observations)
# Tell SB3 to build the policy's feature extractor from CustomFeatureExtractor.
# features_dim must equal the Encoder's output width (its last Linear layer
# outputs hidden_dim features). It is derived from hidden_dim rather than
# hard-coded, so the two cannot drift apart.
policy_kwargs = {
    "features_extractor_class": CustomFeatureExtractor,
    "features_extractor_kwargs": {"features_dim": hidden_dim},
}
# NOTE: this rebinds the module-level name `model` (previously the pre-trained
# Encoder) to the PPO agent. The feature extractor is instantiated inside this
# PPO(...) call, before the rebinding takes effect, so it still wraps the
# Encoder.
model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs)
So far the model trains well, with no issues and good results. Now I want to unfreeze the weights and train the feature extractor as well, starting from the pre-trained weights. How can I do that when my custom feature extractor wraps another, separately defined pre-trained module (one class used inside another class)? My feature extractor differs from the one shown in the documentation, so I am not sure whether it will be trained. Will it start training if I simply unfreeze the layers?
Updated answer
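Your `CustomFeatureExtractor` wraps an `Encoder` whose parameters were already frozen (`requires_grad = False`), so all weights of the feature extractor are frozen. As a result, it is not trained by default. The extractor is still an ordinary `nn.Module` registered inside the policy, and SB3 builds the policy optimizer from all of the policy's parameters. You therefore only need to re-enable gradients, either by skipping the freezing loop or by unfreezing manually after creating the model: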
model = PPO("MlpPolicy", env='FrozenLake8x8', policy_kwargs=policy_kwargs)

# Get the policy's feature extractor (the CustomFeatureExtractor instance).
feature_extr: CustomFeatureExtractor = model.policy.features_extractor

# Make every parameter of the extractor trainable, including the wrapped
# pre-trained Encoder that was frozen before PPO was built. SB3 creates the
# policy optimizer over policy.parameters(), frozen ones included, so no
# optimizer rebuild is needed (verify for your SB3 version).
# NOTE(review): with share_features_extractor=False the value network gets its
# own extractor instance. Here both would wrap the same Encoder object, so
# unfreezing it once covers both.
feature_extr.requires_grad_(True)

# Snapshot the first encoder layer before training so we can tell whether
# training actually changed it.
encoder = feature_extr.model.encoder
before = {
    name: param.detach().clone()
    for name, param in encoder[0].named_parameters()
}
for name, value in before.items():
    print(name, value.mean().item())

# Train the model. Note that SB3 still collects at least one full rollout
# (n_steps per env) even for a tiny total_timesteps.
model.learn(total_timesteps=5)

# The extractor is the same object as before training; the optimizer updates
# its tensors in place. Any difference from the snapshot means the layer trained.
for name, param in encoder[0].named_parameters():
    changed = not torch.equal(before[name], param.detach())
    print(name, param.detach().mean().item(), "changed:", changed)