I'm currently using a GPT-2 model that was trained on German texts. I would like to generate the next word in a text given a context chunk, but instead of using the whole model to predict the next word, I want each of the 12 layers to predict the next word separately, so I get 12 predictions for the next word. Put differently, I want to "lesion" all layers except for one, so they are not involved in the prediction of the next word at all.
This is my model:
# import modules
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# download pre-trained German GPT-2 model & tokenizer
tokenizer = AutoTokenizer.from_pretrained("dbmdz/german-gpt2")
# initialise the model
model = AutoModelForCausalLM.from_pretrained("dbmdz/german-gpt2", pad_token_id = tokenizer.eos_token_id)
And here's an example of a context chunk:
input_text = "Orlando liebte von Natur aus einsame Orte, weite Ausblicke und das Gefühl, für immer und ewig" # correct next word: "allein"
I thought maybe I could set all attention weights to 0 in the layers I want to exclude, but I don't have a clue if that's correct and how to modify the weights in the model. Does anyone have an idea how to solve this & could explain what I need to do? I've never used GPT2 before, so this is super confusing for me.
Thanks in advance for your help / any ideas!
This is technically possible, but it probably won't tell you anything useful about your network. You can think of a network like this as computing y = layer(layer(... (layer(layer(x,theta[0]),theta[1]) ...),theta[n-2]),theta[n-1]), where theta[i] are the weights of the ith layer. Each layer is trained on the output of the layer before it, so in a plain feed-forward stack, setting the weights of layer i to 0 would make the input to layer i+1 garbage. GPT-2 has residual connections around each layer's sub-blocks, so a zeroed layer should mostly pass its input through unchanged rather than wipe it out. Layer i+1 then sees the output of layer i-1, which it was never trained on, so I still wouldn't trust the result to be useful.
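To make the point about residual connections concrete, here is a toy sketch in plain Python rather than PyTorch. The functions `linear` and `residual_block` are made up for illustration and are much simpler than a real GPT-2 block; the only thing they share with it, as far as I understand the architecture, is that the sub-layer's output is added back onto its input.

```python
# Toy residual block in plain Python: out = x + (W @ x + b).
# The names and structure are illustrative, not GPT-2's actual code.

def linear(weight, bias, x):
    """Compute weight @ x + bias for a list-of-lists matrix and list vectors."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
            for row, b_i in zip(weight, bias)]

def residual_block(weight, bias, x):
    """Add the sub-layer's output back onto its input (the residual path)."""
    return [x_i + f_i for x_i, f_i in zip(x, linear(weight, bias, x))]

x = [1.0, -2.0, 0.5]

# A block with some non-zero weights changes its input.
weight = [[0.1, 0.0, 0.2], [0.0, 0.3, 0.0], [0.5, 0.0, 0.1]]
bias = [0.01, 0.02, 0.03]
changed = residual_block(weight, bias, x)

# Zeroing every weight and bias leaves only the residual path.
zero_weight = [[0.0] * 3 for _ in range(3)]
zero_bias = [0.0] * 3
skipped = residual_block(zero_weight, zero_bias, x)

print(changed != x)  # does the intact block transform x?
print(skipped == x)  # does the zeroed block return x unchanged?
```

In this toy setting, zeroing a layer amounts to skipping it, which is not the same as isolating it so that it predicts on its own.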
Nonetheless, if you want to see what happens when you zero out all the weights for a layer, you could set the weights to 0 using the model's state_dict:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
# download pre-trained German GPT-2 model & tokenizer
tokenizer = AutoTokenizer.from_pretrained("dbmdz/german-gpt2")
# initialise the model
model = AutoModelForCausalLM.from_pretrained("dbmdz/german-gpt2", pad_token_id = tokenizer.eos_token_id)
input_text = ["Orlando liebte von Natur aus einsame Orte, weite Ausblicke und das Gefühl, für immer und ewig", # correct next word: "allein"
              "Wo sich Fuchs und Hase gute Nacht", # correct next word: "sagen"
              ]
prompt = [torch.tensor(tokenizer.encode(s)).unsqueeze(0) for s in input_text]
ngenerate = 20
sample_output0 = [tokenizer.decode(model.generate(s,max_length=s.shape[-1]+ngenerate)[0,:]) for s in prompt]
print('\n***Before zeroing***')
for i,s in enumerate(sample_output0):
    print(f'{i}: {s}\n')
# zero-out layer 5
layeri = 5
# find weight names for this layer, will include the string 'transformer.h.5.'
state = model.state_dict()
paramnames = [s for s in state.keys() if re.search(rf'transformer\.h\.{layeri}\.', s) is not None]
# set these weights to 0 in place (state_dict tensors share storage with the model's parameters)
for paramname in paramnames:
    w = state[paramname]
    # skip 0-dim entries, which can't be sliced
    if w.ndim > 0:
        w[:] = 0
# generate some sample output
print('\n***After zeroing***')
sample_output1 = [tokenizer.decode(model.generate(s,max_length=s.shape[-1]+ngenerate)[0,:]) for s in prompt]
for i,s in enumerate(sample_output1):
    print(f'{i}: {s}\n')
The output of this is:
***Before zeroing***
0: Orlando liebte von Natur aus einsame Orte, weite Ausblicke und das Gefühl, für immer und ewig in der Nähe zu sein.
Er war ein großer Künstler, ein Künstler, der sich in der
1: Wo sich Fuchs und Hase gute Nacht sagen.
Die beiden sind seit Jahren befreundet.
Sie sind ein Paar.
Sie sind ein
***After zeroing***
0: Orlando liebte von Natur aus einsame Orte, weite Ausblicke und das Gefühl, für immer und ewig zu sein.
Die Natur ist ein Paradies für sich.
Die Natur ist ein Paradies für sich
1: Wo sich Fuchs und Hase gute Nacht, die Sonne, die Sonne, die Sonne, die Sonne, die Sonne, die Sonne, die
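The name filter that picks out one layer's parameters can be tried on its own, without loading the model. This is a minimal sketch: the key list is hypothetical, written in the style of a GPT-2 state_dict, and `layer_keys` is a helper name I made up. It shows why the dot after the layer index needs to be escaped and matched: otherwise layer 1 would also pick up layers 10 and 11.

```python
import re

# Hypothetical parameter names in the style of a GPT-2 state_dict.
keys = [
    "transformer.wte.weight",
    "transformer.h.1.attn.c_attn.weight",
    "transformer.h.1.mlp.c_proj.bias",
    "transformer.h.10.attn.c_attn.weight",
    "transformer.h.11.ln_1.weight",
    "lm_head.weight",
]

def layer_keys(names, layer):
    """Return the names that belong to block `layer` and to no other block."""
    # Raw f-string so that \. reaches the regex engine as a literal dot.
    pattern = re.compile(rf"transformer\.h\.{layer}\.")
    return [name for name in names if pattern.search(name)]

selected = layer_keys(keys, 1)
print(selected)
```

With the real model, the same check can be run against `model.state_dict().keys()` before zeroing anything, to confirm that only the intended layer is selected.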