[Bugfix] Enable attn quantization of Llama-4 by correctly permuting scales for rope (int8, fp8) (#34243)

Signed-off-by: Your Name <you@example.com>
Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Eldar Kurtić
2026-02-11 19:24:22 +01:00
committed by GitHub
parent be7f3d5d20
commit 11c7ace340

View File

@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors import (
compressed_tensors as ct,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
@@ -829,11 +832,20 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
# Helper function to permute the weight's channels
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
def permute(
w: torch.Tensor,
n_heads: int,
is_nvfp4_weight_scale: bool,
is_ct_int8_or_fp8_weight_scale: bool,
):
# Calculate the expected shape of the weight.
# Do not rely on w's shape, as it may be in another layout.
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
attn_out = (
self.config.hidden_size
if not is_ct_int8_or_fp8_weight_scale
else w.shape[-1]
)
# If the weight is FP4 packed as uint8, we need to divide attn_out
# by 2.
@@ -844,7 +856,7 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
# block size, which is currently 16.
elif (
w.dtype == torch.float8_e4m3fn
and is_weight_scale
and is_nvfp4_weight_scale
and w.shape[1] * 16 == attn_out
):
attn_out = attn_out // 16
@@ -862,19 +874,31 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
is_nvfp4_weight_scale = (
modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
)
is_ct_int8_or_fp8_weight_scale = False
if modules[-1] == "weight_scale" and isinstance(
self.model.quant_config, ct.CompressedTensorsConfig
):
from compressed_tensors import CompressionFormat
if is_weight or is_nvfp4_weight_scale:
is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
if "wk" in modules or "k_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_key_value_heads,
is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
)
elif "wq" in modules or "q_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_attention_heads,
is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
)
return name, loaded_weight