diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 3794bde41..36ee728dc 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -442,6 +442,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported self.is_aiter_triton_fp4_bmm_enabled = ( rocm_aiter_ops.is_fp4bmm_enabled() + and hasattr(self.kv_b_proj, "weight") and self.kv_b_proj.weight.dtype == torch.bfloat16 ) @@ -2492,11 +2493,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] # When FP8 weights are used without FP8 prefill, kv_b_proj expects # model dtype input and will quantize internally. - if ( - use_fp8_prefill - or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype() - ): - kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype) + # For quantized layers (AWQ/GPTQ) that lack a .weight attribute, + # use params_dtype, which is the expected input dtype. + _kv_b_proj_w_dtype = ( + self.kv_b_proj.weight.dtype + if hasattr(self.kv_b_proj, "weight") + else self.kv_b_proj.params_dtype + ) + if use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype(): + kv_c_normed = kv_c_normed.to(_kv_b_proj_w_dtype) k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) kv_nope = self.kv_b_proj(kv_c_normed)[0].view(