[Bugfix] Enable attn quantization of Llama-4 by correctly permuting scales for rope (int8, fp8) (#34243)
Signed-off-by: Your Name <you@example.com> Co-authored-by: Your Name <you@example.com>
This commit is contained in:
@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors import (
|
||||
compressed_tensors as ct,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
@@ -829,11 +832,20 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> tuple[str, torch.Tensor]:
|
||||
# Helper function to permute the weight's channels
|
||||
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
|
||||
def permute(
|
||||
w: torch.Tensor,
|
||||
n_heads: int,
|
||||
is_nvfp4_weight_scale: bool,
|
||||
is_ct_int8_or_fp8_weight_scale: bool,
|
||||
):
|
||||
# Calculate the expected shape of the weight.
|
||||
# Do not rely on w's shape, as it may be in another layout.
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
attn_out = self.config.hidden_size
|
||||
attn_out = (
|
||||
self.config.hidden_size
|
||||
if not is_ct_int8_or_fp8_weight_scale
|
||||
else w.shape[-1]
|
||||
)
|
||||
|
||||
# If the weight is FP4 packed as uint8, we need to divide attn_out
|
||||
# by 2.
|
||||
@@ -844,7 +856,7 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
|
||||
# block size, which is currently 16.
|
||||
elif (
|
||||
w.dtype == torch.float8_e4m3fn
|
||||
and is_weight_scale
|
||||
and is_nvfp4_weight_scale
|
||||
and w.shape[1] * 16 == attn_out
|
||||
):
|
||||
attn_out = attn_out // 16
|
||||
@@ -862,19 +874,31 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
|
||||
is_nvfp4_weight_scale = (
|
||||
modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
|
||||
)
|
||||
is_ct_int8_or_fp8_weight_scale = False
|
||||
if modules[-1] == "weight_scale" and isinstance(
|
||||
self.model.quant_config, ct.CompressedTensorsConfig
|
||||
):
|
||||
from compressed_tensors import CompressionFormat
|
||||
|
||||
if is_weight or is_nvfp4_weight_scale:
|
||||
is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.float_quantized.value,
|
||||
] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
|
||||
if "wk" in modules or "k_proj" in modules:
|
||||
loaded_weight = permute(
|
||||
loaded_weight,
|
||||
self.config.num_key_value_heads,
|
||||
is_nvfp4_weight_scale,
|
||||
is_ct_int8_or_fp8_weight_scale,
|
||||
)
|
||||
elif "wq" in modules or "q_proj" in modules:
|
||||
loaded_weight = permute(
|
||||
loaded_weight,
|
||||
self.config.num_attention_heads,
|
||||
is_nvfp4_weight_scale,
|
||||
is_ct_int8_or_fp8_weight_scale,
|
||||
)
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
Reference in New Issue
Block a user