I am using model = 'filipealmeida/Mistral-7B-Instruct-v0.1-sharded' and quantizing it to 4-bit with the following function.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

def load_quantized_model(model_name: str):
    """
    :param model_name: Name or path of the model to be loaded.
    :return: Loaded quantized model.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_4bit=True,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config
    )
    return model
When I call this function to load the model, I get the following error message:
ValueError                                Traceback (most recent call last)
Cell In[12], line 1
----> 1 model = load_quantized_model(model_name)

Cell In[10], line 13
      2 """
      3 :param model_name: Name or path of the model to be loaded.
      4 :return: Loaded quantized model.
      5 """
      6 bnb_config = BitsAndBytesConfig(
      7     load_in_4bit=True,
      8     bnb_4bit_use_double_quant=True,
      9     bnb_4bit_quant_type="nf4",
     10     bnb_4bit_compute_dtype=torch.bfloat16
     11 )
---> 13 model = AutoModelForCausalLM.from_pretrained(
     14     model_name,
     15     load_in_4bit=True,
     16     torch_dtype=torch.bfloat16,
     17     quantization_config=bnb_config
     18 )
     20 return model

File ~/miniconda3/envs/peft/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:563, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
...
   2981 )
   2983 # preparing BitsAndBytesConfig from kwargs
   2984 config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters}

ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time.
I checked the class _BaseAutoModelClass.from_pretrained but I cannot find where load_in_8bit is set. What do I need to do to load the model correctly in 4-bit? I also tried changing the bnb_config to 8-bit instead, but that did not solve the problem.
Try removing load_in_4bit=True from the AutoModelForCausalLM.from_pretrained() call. You already set load_in_4bit=True inside the BitsAndBytesConfig, and, as the ValueError says, transformers refuses the flag when it is passed both as a kwarg and via quantization_config.
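
For reference, here is a minimal sketch of the corrected function, using the same configuration as in your question with only the redundant kwarg dropped:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

def load_quantized_model(model_name: str):
    """
    :param model_name: Name or path of the model to be loaded.
    :return: Loaded quantized model.
    """
    # All 4-bit settings live in the BitsAndBytesConfig...
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # ...so from_pretrained() only receives the config object,
    # not a separate load_in_4bit kwarg.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config
    )
    return model

model = load_quantized_model("filipealmeida/Mistral-7B-Instruct-v0.1-sharded")

The standalone load_in_4bit/load_in_8bit kwargs predate BitsAndBytesConfig; once you pass a quantization_config, the config object is the single source of truth for quantization settings.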