[Models][Quantization] Add quantization configuration update in Voxtral model (#24122)

Signed-off-by: Alexandre Marques <almarque@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Alexandre Marques
2025-09-10 22:13:56 -04:00
committed by GitHub
parent cc99baf14d
commit 5931b7e5d9
2 changed files with 88 additions and 4 deletions

@@ -626,9 +626,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
-        def permute(w: torch.Tensor, n_heads: int):
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
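
The attn_out dimension is now supplied by the caller instead of being fixed to hidden_size, because the same rotary row permutation must be applied both to the (attn_in, hidden_size) weight matrices and to per-output-channel quantization scales, which have a single column. A minimal sketch (shapes and config values below are hypothetical, not taken from the commit) showing that the row reordering does not depend on the column count:

    import torch

    def permute(w: torch.Tensor, n_heads: int, attn_out: int, head_dim: int):
        # Same row reordering as in the diff: interleaved rotary layout ->
        # half-split layout, applied per attention head.
        attn_in = head_dim * n_heads
        return w.view(n_heads, attn_in // n_heads // 2, 2,
                      attn_out).transpose(1, 2).reshape(attn_in, attn_out)

    n_kv_heads, head_dim, hidden_size = 8, 128, 4096          # hypothetical sizes
    wk = torch.randn(n_kv_heads * head_dim, hidden_size)      # weight matrix
    qscale = torch.rand(n_kv_heads * head_dim, 1)              # one scale per row

    # Permuting the weight with attn_out=hidden_size and the scale with
    # attn_out=1 applies the identical row permutation to both, so each
    # scale stays attached to its weight row.
    both = permute(torch.cat([wk, qscale], dim=1),
                   n_kv_heads, hidden_size + 1, head_dim)
    assert torch.equal(both[:, -1:], permute(qscale, n_kv_heads, 1, head_dim))
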
@@ -637,12 +636,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         modules = name.split(".")
 
         # rotary embeds should be sliced
+        # If using quantized model in mistral format,
+        # quantization scales (qscale_weight) also need to be sliced
         if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
+                                    self.config.num_key_value_heads,
+                                    self.config.hidden_size)
+        elif "wk" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads, 1)
         elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+                                    self.config.num_attention_heads,
+                                    self.config.hidden_size)
+        elif "wq" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads, 1)
 
         num_modules = len(modules)
         for i in range(num_modules):
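
For Mistral-format checkpoints quantized per output channel, the wq/wk scales (qscale_weight) carry one value per row of the weight, so they must be reordered with the same permutation as the weight rows; a scalar per-tensor scale (numel() == 1) needs no reordering, which is what the numel() > 1 guard checks. A rough, self-contained sketch of this dispatch (function names and config values below are illustrative, not the vLLM API):

    import torch

    # Hypothetical config values for illustration only.
    num_attention_heads, num_key_value_heads = 32, 8
    head_dim, hidden_size = 128, 4096

    def permute(w: torch.Tensor, n_heads: int, attn_out: int) -> torch.Tensor:
        attn_in = head_dim * n_heads
        return w.view(n_heads, attn_in // n_heads // 2, 2,
                      attn_out).transpose(1, 2).reshape(attn_in, attn_out)

    def remap(name: str, loaded_weight: torch.Tensor) -> torch.Tensor:
        modules = name.split(".")
        if "wk" in modules and modules[-1] == "weight":
            return permute(loaded_weight, num_key_value_heads, hidden_size)
        if ("wk" in modules and modules[-1] == "qscale_weight"
                and loaded_weight.numel() > 1):
            # Per-channel scales: one value per weight row, so attn_out=1.
            return permute(loaded_weight, num_key_value_heads, 1)
        if "wq" in modules and modules[-1] == "weight":
            return permute(loaded_weight, num_attention_heads, hidden_size)
        if ("wq" in modules and modules[-1] == "qscale_weight"
                and loaded_weight.numel() > 1):
            return permute(loaded_weight, num_attention_heads, 1)
        # Scalar (per-tensor) scales and all other tensors pass through untouched.
        return loaded_weight

    # Per-channel wk scale is reordered to follow its weight rows.
    scale = torch.rand(num_key_value_heads * head_dim, 1)
    print(remap("layers.0.attention.wk.qscale_weight", scale).shape)  # [1024, 1]

    # A per-tensor (scalar) scale is left as-is.
    print(remap("layers.0.attention.wk.qscale_weight", torch.tensor([0.02])))
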