diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 501b939c1..1b719330e 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -522,22 +522,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]

-        assert (
-            attn_metadata.num_decodes is not None
-            and attn_metadata.num_prefills is not None
-            and attn_metadata.num_decode_tokens is not None
-        )
-
-        has_decode = attn_metadata.num_decodes > 0
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
-
-        decode_q = q[:num_decode_tokens]
-
-        prefill_q = q[num_decode_tokens:]
-        prefill_k_pe = k_pe[num_decode_tokens:]
-        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
-
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
@@ -555,27 +539,32 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         # Sparse MLA impls only support forward_mqa (decode-style attention)
         is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)

-        if has_prefill and not is_sparse_impl:
+        if is_sparse_impl:
+            num_mqa_tokens = q.size(0)
+            num_mha_tokens = 0
+        else:
+            assert (
+                attn_metadata.num_decodes is not None
+                and attn_metadata.num_prefills is not None
+                and attn_metadata.num_decode_tokens is not None
+            )
+            num_mqa_tokens = attn_metadata.num_decode_tokens
+            num_mha_tokens = q.size(0) - num_mqa_tokens
+
+        if num_mha_tokens > 0:
             self.impl.forward_mha(
-                prefill_q,
-                prefill_k_c_normed,
-                prefill_k_pe,
+                q[num_mqa_tokens:],
+                k_c_normed[num_mqa_tokens:],
+                k_pe[num_mqa_tokens:],
                 kv_cache,
                 attn_metadata,
                 self._k_scale,
-                output=output[num_decode_tokens:],
+                output=output[num_mqa_tokens:],
             )

-        if has_decode or (has_prefill and is_sparse_impl):
-            # For sparse impl, we always use forward_mqa for all tokens
-            # For non-sparse impl, we only use forward_mqa for decode tokens
-            if is_sparse_impl:
-                mqa_q = q
-                mqa_output_slice = output
-            else:
-                assert attn_metadata.decode is not None
-                mqa_q = decode_q
-                mqa_output_slice = output[:num_decode_tokens]
+        if num_mqa_tokens > 0:
+            mqa_q = q[:num_mqa_tokens]
+            mqa_output_slice = output[:num_mqa_tokens]

             mqa_q_nope, mqa_q_pe = mqa_q.split(
                 [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
@@ -644,6 +633,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)

         # call decode attn
+        if not is_sparse_impl:
+            assert attn_metadata.decode is not None
         attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

         # correct dcp attn_out with lse.
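
For context, here is a minimal standalone sketch of the token-split logic this diff introduces. It assumes vLLM's convention that decode tokens are ordered before prefill tokens in the batch, so a single split index separates the MQA (decode-style) path from the MHA (prefill-style) path. `FakeMetadata` and `split_mqa_mha` are hypothetical stand-ins for illustration; in the real code these fields live on `attn_metadata` and the branch sits inline in `MLAAttention.forward`.

from dataclasses import dataclass

import torch


@dataclass
class FakeMetadata:
    # Hypothetical stand-in for the relevant attn_metadata field.
    # Decode tokens come first in the batch, so this count doubles
    # as the split index between the two attention paths.
    num_decode_tokens: int | None


def split_mqa_mha(q: torch.Tensor, meta: FakeMetadata, is_sparse_impl: bool):
    """Return (num_mqa_tokens, num_mha_tokens), mirroring the rewritten forward()."""
    if is_sparse_impl:
        # Sparse MLA impls route every token through forward_mqa.
        num_mqa_tokens = q.size(0)
        num_mha_tokens = 0
    else:
        assert meta.num_decode_tokens is not None
        num_mqa_tokens = meta.num_decode_tokens
        num_mha_tokens = q.size(0) - num_mqa_tokens
    return num_mqa_tokens, num_mha_tokens


q = torch.randn(10, 16, 192)  # [tokens, heads, head_dim]; sizes arbitrary
meta = FakeMetadata(num_decode_tokens=4)

# Dense impl: first 4 tokens take forward_mqa, remaining 6 take forward_mha.
assert split_mqa_mha(q, meta, is_sparse_impl=False) == (4, 6)

# Sparse impl: all 10 tokens take forward_mqa; the forward_mha branch is skipped.
assert split_mqa_mha(q, meta, is_sparse_impl=True) == (10, 0)

Expressing both branches as token counts (rather than the old has_decode/has_prefill flags plus per-path tensor slices) lets the two `if` bodies share one pair of slice expressions and removes the nested sparse/non-sparse special-casing from the MQA path.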