Revert "[Llama4,Quantization] Simplify and generalize logic for Q/K permutations in quantized self-attn layers" (#34997)

This commit is contained in:
Lucas Wilkinson
2026-02-20 20:19:19 -05:00
committed by GitHub
parent ea5f903f80
commit 0e22cd618b

View File

@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors import (
compressed_tensors as ct,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
@@ -828,38 +831,74 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
name: str,
loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
modules = name.split(".")
# Permute Q/K weights and corresponding scales for rotary embedding.
# This pathway is validated against modelopt and compressed-tensors ckpts,
# and for per-tensor, per-group (e.g. GPTQ), and per-channel quant schemes.
# Note: permutations are not feasible only for per-block (e.g. DeepSeek 128x128)
# For per-block quantization, consider not quantizing q/k_proj.
is_weight = modules[-1] in ("weight", "weight_packed")
is_weight_scale = (
modules[-1] == "weight_scale"
and loaded_weight.numel() > 1 # no need to permute per-tensor scales
)
is_k_proj = "wk" in modules or "k_proj" in modules
is_q_proj = "wq" in modules or "q_proj" in modules
if (is_weight or is_weight_scale) and (is_k_proj or is_q_proj):
original_ndim = loaded_weight.ndim
if original_ndim == 1:
loaded_weight = loaded_weight.unsqueeze(-1)
f_out, f_in = loaded_weight.shape
n_heads = (
self.config.num_key_value_heads
if is_k_proj
else self.config.num_attention_heads
# Helper function to permute the weight's channels
def permute(
w: torch.Tensor,
n_heads: int,
is_nvfp4_weight_scale: bool,
is_ct_int8_or_fp8_weight_scale: bool,
):
# Calculate the expected shape of the weight.
# Do not rely on w's shape, as it may be in another layout.
attn_in = self.config.head_dim * n_heads
attn_out = (
self.config.hidden_size
if not is_ct_int8_or_fp8_weight_scale
else w.shape[-1]
)
loaded_weight = (
loaded_weight.view(n_heads, f_out // n_heads // 2, 2, f_in)
# If the weight is FP4 packed as uint8, we need to divide attn_out
# by 2.
if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
attn_out = attn_out // 2
# If the weight is a weight scale, we need to divide attn_out by
# block size, which is currently 16.
elif (
w.dtype == torch.float8_e4m3fn
and is_nvfp4_weight_scale
and w.shape[1] * 16 == attn_out
):
attn_out = attn_out // 16
return (
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
.transpose(1, 2)
.reshape(f_out, f_in)
.reshape(attn_in, attn_out)
)
if original_ndim == 1:
loaded_weight = loaded_weight.squeeze(-1)
modules = name.split(".")
# Permute Q/K weights and weight block scales for rotary embedding
is_weight = modules[-1] == "weight"
is_nvfp4_weight_scale = (
modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
)
is_ct_int8_or_fp8_weight_scale = False
if modules[-1] == "weight_scale" and isinstance(
self.model.quant_config, ct.CompressedTensorsConfig
):
from compressed_tensors import CompressionFormat
is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
if "wk" in modules or "k_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_key_value_heads,
is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
)
elif "wq" in modules or "q_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_attention_heads,
is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
)
return name, loaded_weight