From "How does one reinitialize the weights of a Hugging Face LLaMA v2 model the official way as the original model?" and https://discuss.huggingface.co/t/how-does-one-reinitialize-the-weights-of-a-hugging-face-llama-v2-model-the-official-way-as-the-original-model/62547/4, there are different suggestions for reinitializing the model.
When I tried the following, it seems to work:
from transformers import AutoModelForCausalLM, AutoConfig

# Load the pretrained weights
m = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.3", token="hf_*****")
# Build the same architecture with freshly initialized (random) weights
c = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.3")
m2 = AutoModelForCausalLM.from_config(c)
# Compare one layer: freshly initialized vs. pretrained
print(m2.model.layers[0].mlp.down_proj.state_dict())
print(m.model.layers[0].mlp.down_proj.state_dict())
[out]:
OrderedDict([('weight',
tensor([[ 0.0315, -0.0025, -0.0015, ..., -0.0022, 0.0168, -0.0296],
[-0.0013, -0.0190, -0.0103, ..., 0.0037, 0.0021, -0.0374],
[-0.0378, -0.0230, 0.0031, ..., -0.0035, 0.0099, -0.0027],
...,
[-0.0029, 0.0042, -0.0041, ..., -0.0003, 0.0396, -0.0012],
[-0.0487, -0.0050, -0.0068, ..., 0.0170, 0.0135, -0.0006],
[ 0.0103, 0.0424, 0.0019, ..., 0.0155, 0.0254, 0.0061]]))])
OrderedDict([('weight',
tensor([[-0.0027, -0.0004, -0.0007, ..., -0.0025, 0.0032, -0.0014],
[ 0.0012, -0.0047, 0.0026, ..., -0.0017, 0.0015, -0.0044],
[ 0.0056, -0.0084, 0.0027, ..., 0.0026, -0.0053, 0.0038],
...,
[ 0.0052, 0.0017, -0.0019, ..., -0.0013, 0.0052, -0.0017],
[-0.0032, 0.0029, -0.0014, ..., 0.0003, 0.0006, 0.0023],
[-0.0023, -0.0045, -0.0013, ..., -0.0036, 0.0002, -0.0008]]))])
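A more direct check than eyeballing the printed tensors (my addition, reusing m and m2 from the snippet above) is to compare the two weight matrices programmatically:

import torch
w_init = m2.model.layers[0].mlp.down_proj.weight
w_pre = m.model.layers[0].mlp.down_proj.weight
# The weights were re-drawn, so the tensors differ
print(torch.equal(w_init, w_pre))  # False
# The fresh init has a noticeably wider spread than the trained weights
print(w_init.std().item(), w_pre.std().item())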
How are the layers re-initialized through the from_config function? Is it using Xavier/He initialization or just random initialization?
MistralConfig has a parameter initializer_range that defaults to 0.02 and is documented as "The standard deviation of the truncated_normal_initializer for initializing all weight matrices", so one can assume a truncated normal distribution with a standard deviation of 0.02 is used.
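Rather than taking the docstring on faith, you can inspect the initialization hook directly. A minimal sketch, assuming the standard transformers layout where each PreTrainedModel subclass defines a (private) _init_weights method that gets applied to every module when the model is built from a config:

import inspect
from transformers import AutoModelForCausalLM, AutoConfig

c = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.3")
m2 = AutoModelForCausalLM.from_config(c)
# The config value controlling the spread of the initial weights
print(c.initializer_range)  # 0.02
# Print the source of the init hook to see exactly how weights are drawn
print(inspect.getsource(m2._init_weights))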
If you plot the actual distribution of the freshly initialized weights against a truncated normal distribution with a standard deviation of 0.02, it looks like a good fit to me:
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import truncnorm
from transformers import AutoModelForCausalLM, AutoConfig
# histogram of actual weights distribution
c = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.3")
m2 = AutoModelForCausalLM.from_config(c)
weights = m2.model.layers[0].mlp.down_proj.state_dict()['weight'].ravel()
plt.hist(weights, bins=np.linspace(-0.1, 0.1, 100), histtype='step', density=True, label='model weights')
# what a truncated normal distribution with mean 0 and std 0.02 is supposed to look like
# (the +/-0.1 truncation bounds, i.e. 5 standard deviations, are chosen here
# simply to match the histogram range above)
lower = -0.1
upper = 0.1
mean = 0
std = 0.02
a, b = (lower - mean) / std, (upper - mean) / std
x = np.linspace(lower, upper, 1000)
plt.plot(x, truncnorm.pdf(x, a, b, loc=mean, scale=std), label='expected')
plt.legend()
plt.show()
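As a quick numeric companion to the plot (my addition, reusing weights from the snippet above): the empirical standard deviation of the freshly initialized weights should sit close to initializer_range.

# Empirical spread of the fresh weights vs. the configured std of 0.02
print(float(weights.std()))        # should be ~0.02
print(float(weights.abs().max()))  # how far the tails actually reach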