Revert "[Llama4,Quantization] Simplify and generalize logic for Q/K permutations in quantized self-attn layers " (#34997)
This commit is contained in:
@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors import (
|
||||
compressed_tensors as ct,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
@@ -828,38 +831,74 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> tuple[str, torch.Tensor]:
|
||||
modules = name.split(".")
|
||||
# Permute Q/K weights and corresponding scales for rotary embedding.
|
||||
# This pathway is validated against modelopt and compressed-tensors ckpts,
|
||||
# and for per-tensor, per-group (e.g. GPTQ), and per-channel quant schemes.
|
||||
# Note: permutations are not feasible only for per-block (e.g. DeepSeek 128x128)
|
||||
# For per-block quantization, consider not quantizing q/k_proj.
|
||||
is_weight = modules[-1] in ("weight", "weight_packed")
|
||||
is_weight_scale = (
|
||||
modules[-1] == "weight_scale"
|
||||
and loaded_weight.numel() > 1 # no need to permute per-tensor scales
|
||||
)
|
||||
is_k_proj = "wk" in modules or "k_proj" in modules
|
||||
is_q_proj = "wq" in modules or "q_proj" in modules
|
||||
|
||||
if (is_weight or is_weight_scale) and (is_k_proj or is_q_proj):
|
||||
original_ndim = loaded_weight.ndim
|
||||
if original_ndim == 1:
|
||||
loaded_weight = loaded_weight.unsqueeze(-1)
|
||||
|
||||
f_out, f_in = loaded_weight.shape
|
||||
n_heads = (
|
||||
self.config.num_key_value_heads
|
||||
if is_k_proj
|
||||
else self.config.num_attention_heads
|
||||
# Helper function to permute the weight's channels
|
||||
def permute(
|
||||
w: torch.Tensor,
|
||||
n_heads: int,
|
||||
is_nvfp4_weight_scale: bool,
|
||||
is_ct_int8_or_fp8_weight_scale: bool,
|
||||
):
|
||||
# Calculate the expected shape of the weight.
|
||||
# Do not rely on w's shape, as it may be in another layout.
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
attn_out = (
|
||||
self.config.hidden_size
|
||||
if not is_ct_int8_or_fp8_weight_scale
|
||||
else w.shape[-1]
|
||||
)
|
||||
loaded_weight = (
|
||||
loaded_weight.view(n_heads, f_out // n_heads // 2, 2, f_in)
|
||||
|
||||
# If the weight is FP4 packed as uint8, we need to divide attn_out
|
||||
# by 2.
|
||||
if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
|
||||
attn_out = attn_out // 2
|
||||
|
||||
# If the weight is a weight scale, we need to divide attn_out by
|
||||
# block size, which is currently 16.
|
||||
elif (
|
||||
w.dtype == torch.float8_e4m3fn
|
||||
and is_nvfp4_weight_scale
|
||||
and w.shape[1] * 16 == attn_out
|
||||
):
|
||||
attn_out = attn_out // 16
|
||||
|
||||
return (
|
||||
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
|
||||
.transpose(1, 2)
|
||||
.reshape(f_out, f_in)
|
||||
.reshape(attn_in, attn_out)
|
||||
)
|
||||
|
||||
if original_ndim == 1:
|
||||
loaded_weight = loaded_weight.squeeze(-1)
|
||||
modules = name.split(".")
|
||||
|
||||
# Permute Q/K weights and weight block scales for rotary embedding
|
||||
is_weight = modules[-1] == "weight"
|
||||
is_nvfp4_weight_scale = (
|
||||
modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
|
||||
)
|
||||
is_ct_int8_or_fp8_weight_scale = False
|
||||
if modules[-1] == "weight_scale" and isinstance(
|
||||
self.model.quant_config, ct.CompressedTensorsConfig
|
||||
):
|
||||
from compressed_tensors import CompressionFormat
|
||||
|
||||
is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.float_quantized.value,
|
||||
] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
|
||||
if "wk" in modules or "k_proj" in modules:
|
||||
loaded_weight = permute(
|
||||
loaded_weight,
|
||||
self.config.num_key_value_heads,
|
||||
is_nvfp4_weight_scale,
|
||||
is_ct_int8_or_fp8_weight_scale,
|
||||
)
|
||||
elif "wq" in modules or "q_proj" in modules:
|
||||
loaded_weight = permute(
|
||||
loaded_weight,
|
||||
self.config.num_attention_heads,
|
||||
is_nvfp4_weight_scale,
|
||||
is_ct_int8_or_fp8_weight_scale,
|
||||
)
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
Reference in New Issue
Block a user