[Llama4,Quantization] Simplify and generalize logic for Q/K permutations in quantized self-attn layers (#34471)
Signed-off-by: Your Name <you@example.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
@@ -44,9 +44,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.compressed_tensors import (
-    compressed_tensors as ct,
-)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
@@ -831,74 +828,38 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
         name: str,
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
-        # Helper function to permute the weight's channels
-        def permute(
-            w: torch.Tensor,
-            n_heads: int,
-            is_nvfp4_weight_scale: bool,
-            is_ct_int8_or_fp8_weight_scale: bool,
-        ):
-            # Calculate the expected shape of the weight.
-            # Do not rely on w's shape, as it may be in another layout.
-            attn_in = self.config.head_dim * n_heads
-            attn_out = (
-                self.config.hidden_size
-                if not is_ct_int8_or_fp8_weight_scale
-                else w.shape[-1]
-            )
-
-            # If the weight is FP4 packed as uint8, we need to divide attn_out
-            # by 2.
-            if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
-                attn_out = attn_out // 2
-
-            # If the weight is a weight scale, we need to divide attn_out by
-            # block size, which is currently 16.
-            elif (
-                w.dtype == torch.float8_e4m3fn
-                and is_nvfp4_weight_scale
-                and w.shape[1] * 16 == attn_out
-            ):
-                attn_out = attn_out // 16
-
-            return (
-                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
-                .transpose(1, 2)
-                .reshape(attn_in, attn_out)
-            )
-
         modules = name.split(".")
-
-        # Permute Q/K weights and weight block scales for rotary embedding
-        is_weight = modules[-1] == "weight"
-        is_nvfp4_weight_scale = (
-            modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
+        # Permute Q/K weights and corresponding scales for rotary embedding.
+        # This pathway is validated against modelopt and compressed-tensors ckpts,
+        # and for per-tensor, per-group (e.g. GPTQ), and per-channel quant schemes.
+        # Note: permutations are not feasible only for per-block (e.g. DeepSeek 128x128)
+        # For per-block quantization, consider not quantizing q/k_proj.
+        is_weight = modules[-1] in ("weight", "weight_packed")
+        is_weight_scale = (
+            modules[-1] == "weight_scale"
+            and loaded_weight.numel() > 1  # no need to permute per-tensor scales
         )
-        is_ct_int8_or_fp8_weight_scale = False
-        if modules[-1] == "weight_scale" and isinstance(
-            self.model.quant_config, ct.CompressedTensorsConfig
-        ):
-            from compressed_tensors import CompressionFormat
+        is_k_proj = "wk" in modules or "k_proj" in modules
+        is_q_proj = "wq" in modules or "q_proj" in modules
 
-            is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
-                CompressionFormat.int_quantized.value,
-                CompressionFormat.float_quantized.value,
-            ] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
-
-        if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
-            if "wk" in modules or "k_proj" in modules:
-                loaded_weight = permute(
-                    loaded_weight,
-                    self.config.num_key_value_heads,
-                    is_nvfp4_weight_scale,
-                    is_ct_int8_or_fp8_weight_scale,
-                )
-            elif "wq" in modules or "q_proj" in modules:
-                loaded_weight = permute(
-                    loaded_weight,
-                    self.config.num_attention_heads,
-                    is_nvfp4_weight_scale,
-                    is_ct_int8_or_fp8_weight_scale,
-                )
+        if (is_weight or is_weight_scale) and (is_k_proj or is_q_proj):
+            original_ndim = loaded_weight.ndim
+            if original_ndim == 1:
+                loaded_weight = loaded_weight.unsqueeze(-1)
+
+            f_out, f_in = loaded_weight.shape
+            n_heads = (
+                self.config.num_key_value_heads
+                if is_k_proj
+                else self.config.num_attention_heads
+            )
+            loaded_weight = (
+                loaded_weight.view(n_heads, f_out // n_heads // 2, 2, f_in)
+                .transpose(1, 2)
+                .reshape(f_out, f_in)
+            )
+
+            if original_ndim == 1:
+                loaded_weight = loaded_weight.squeeze(-1)
 
         return name, loaded_weight
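For readers who want to see the row permutation in isolation: the sketch below is a minimal, self-contained illustration of the pattern used in the hunk above, not vLLM's actual method, and the helper name permute_interleaved_to_half_split is hypothetical. It shows why a single view/transpose/reshape reorders the output rows of a 2-D Q/K projection weight and, after an unsqueeze, the entries of a 1-D per-channel scale in exactly the same way, so each weight row stays paired with its scale.

import torch


def permute_interleaved_to_half_split(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Hypothetical standalone helper. For each head, rows stored as interleaved
    # rotary pairs (even, odd, even, odd, ...) are regrouped so all even-indexed
    # rows come first, then all odd-indexed rows. The last dim (input features)
    # is untouched, so this is a pure reordering of output rows.
    original_ndim = w.ndim
    if original_ndim == 1:
        # A per-channel scale has one entry per output row; view it as a column
        # so the identical row permutation applies.
        w = w.unsqueeze(-1)
    f_out, f_in = w.shape
    w = (
        w.view(n_heads, f_out // n_heads // 2, 2, f_in)
        .transpose(1, 2)
        .reshape(f_out, f_in)
    )
    return w.squeeze(-1) if original_ndim == 1 else w


if __name__ == "__main__":
    n_heads, head_dim, hidden = 2, 4, 6
    f_out = n_heads * head_dim

    # Toy Q/K weight where row i starts with the value i * hidden, and a toy
    # per-channel scale where entry i is simply i.
    weight = torch.arange(f_out * hidden, dtype=torch.float32).reshape(f_out, hidden)
    scale = torch.arange(f_out, dtype=torch.float32)

    permuted_w = permute_interleaved_to_half_split(weight, n_heads)
    permuted_s = permute_interleaved_to_half_split(scale, n_heads)

    # The weight and its per-row scale are reordered identically, so row i of
    # the permuted weight still corresponds to entry i of the permuted scale.
    assert torch.equal(permuted_w[:, 0], permuted_s * hidden)
    print(permuted_s)  # tensor([0., 2., 1., 3., 4., 6., 5., 7.])

Because the permutation never touches the input-feature dimension, it works for any quantization scheme whose scales are per-tensor (skipped via the numel() > 1 check), per-channel, or per-group along the input dimension; only per-block scales tied to fixed output-row blocks would be broken by it, which is what the comment in the diff cautions about.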