[Bugfix] Fix MLA attention crash with AWQ/GPTQ quantized models (#34695)

Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
haosdent
2026-03-14 07:25:41 +08:00
committed by GitHub
parent 8b346309a5
commit 6d53efd2a5

View File

@@ -442,6 +442,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
self.is_aiter_triton_fp4_bmm_enabled = (
rocm_aiter_ops.is_fp4bmm_enabled()
and hasattr(self.kv_b_proj, "weight")
and self.kv_b_proj.weight.dtype == torch.bfloat16
)
@@ -2492,11 +2493,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
# When FP8 weights are used without FP8 prefill, kv_b_proj expects
# model dtype input and will quantize internally.
if (
use_fp8_prefill
or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype()
):
kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)
# For quantized layers (AWQ/GPTQ) that lack a .weight attribute,
# use params_dtype which is the expected input dtype.
_kv_b_proj_w_dtype = (
self.kv_b_proj.weight.dtype
if hasattr(self.kv_b_proj, "weight")
else self.kv_b_proj.params_dtype
)
if use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype():
kv_c_normed = kv_c_normed.to(_kv_b_proj_w_dtype)
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(