[Models][Quantization] Add quantization configuration update in Voxtral model (#24122)

Signed-off-by: Alexandre Marques <almarque@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
commit 5931b7e5d9 (parent cc99baf14d)
@@ -626,9 +626,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
 
-        def permute(w: torch.Tensor, n_heads: int):
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
 
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
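The hunk above turns the hard-coded attn_out = self.config.hidden_size into a parameter, so the same row reordering can later be applied to tensors whose second dimension is not hidden_size. As context, here is a minimal standalone sketch of what the helper does (toy sizes are hypothetical, and self.config.head_dim is passed as an explicit head_dim argument):

import torch

def permute(w: torch.Tensor, n_heads: int, attn_out: int, head_dim: int):
    # Standalone copy of the helper in the diff. Rows within each head are
    # reordered from interleaved rotary pairs (a0, b0, a1, b1, ...) to
    # concatenated halves (a0, a1, ..., b0, b1, ...).
    attn_in = head_dim * n_heads
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)

# Toy example: one head, head_dim=4, hidden_size=4 (hypothetical sizes).
w = torch.arange(16.0).reshape(4, 4)
print(permute(w, n_heads=1, attn_out=4, head_dim=4))
# Rows come out in order 0, 2, 1, 3: pair-interleaved -> half-and-half.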
@@ -637,12 +636,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         modules = name.split(".")
 
         # rotary embeds should be sliced
+        # If using quantized model in mistral format,
+        # quantization scales (qscale_weight) also need to be sliced
         if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
+                                    self.config.num_key_value_heads,
+                                    self.config.hidden_size)
+        elif "wk" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads, 1)
         elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+                                    self.config.num_attention_heads,
+                                    self.config.hidden_size)
+        elif "wq" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads, 1)
 
         num_modules = len(modules)
         for i in range(num_modules):
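The new qscale_weight branches exist because, for a quantized checkpoint in mistral format, per-channel quantization scales carry one value per row of the wq/wk weight; once the weight rows are permuted for rotary embeddings, the scales must follow the identical permutation or they would pair with the wrong rows. Passing attn_out=1 treats the 1-D scale vector as an (attn_in, 1) column, and the numel() > 1 guard skips per-tensor (scalar) scales, which need no reordering. A minimal sketch of the invariant, with hypothetical toy sizes and stand-in tensors:

import torch

head_dim, n_kv_heads, hidden = 4, 2, 8   # hypothetical toy sizes
attn_in = head_dim * n_kv_heads

def permute(w: torch.Tensor, n_heads: int, attn_out: int):
    # Same reordering as the helper in the diff.
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)

w = torch.randn(attn_in, hidden)   # stand-in for a wk weight
qscale = torch.rand(attn_in)       # stand-in for per-channel qscale_weight

w_p = permute(w, n_kv_heads, hidden)
s_p = permute(qscale, n_kv_heads, 1).flatten()

# Both tensors undergo the same row permutation, so each permuted weight
# row still pairs with its own scale.
perm = permute(torch.arange(attn_in, dtype=torch.float32),
               n_kv_heads, 1).flatten().long()
assert torch.equal(w_p, w[perm])
assert torch.equal(s_p, qscale[perm])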