[Bugfix] Fix MLA attention crash with AWQ/GPTQ quantized models (#34695)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -442,6 +442,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
         self.is_aiter_triton_fp4_bmm_enabled = (
             rocm_aiter_ops.is_fp4bmm_enabled()
+            and hasattr(self.kv_b_proj, "weight")
             and self.kv_b_proj.weight.dtype == torch.bfloat16
         )
 
@@ -2492,11 +2493,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
         # When FP8 weights are used without FP8 prefill, kv_b_proj expects
         # model dtype input and will quantize internally.
-        if (
-            use_fp8_prefill
-            or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype()
-        ):
-            kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)
+        # For quantized layers (AWQ/GPTQ) that lack a .weight attribute,
+        # use params_dtype which is the expected input dtype.
+        _kv_b_proj_w_dtype = (
+            self.kv_b_proj.weight.dtype
+            if hasattr(self.kv_b_proj, "weight")
+            else self.kv_b_proj.params_dtype
+        )
+        if use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype():
+            kv_c_normed = kv_c_normed.to(_kv_b_proj_w_dtype)
 
         k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
         kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
Reference in New Issue
Block a user