[Bugfix] Fix sparse MLA metadata building (#33579)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-02-03 18:29:48 -05:00
committed by GitHub
parent 2a99c5a6c8
commit bd8da29a66

View File

@@ -522,22 +522,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
assert (
attn_metadata.num_decodes is not None
and attn_metadata.num_prefills is not None
and attn_metadata.num_decode_tokens is not None
)
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
@@ -555,27 +539,32 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# Sparse MLA impls only support forward_mqa (decode-style attention)
is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)
if has_prefill and not is_sparse_impl:
if is_sparse_impl:
num_mqa_tokens = q.size(0)
num_mha_tokens = 0
else:
assert (
attn_metadata.num_decodes is not None
and attn_metadata.num_prefills is not None
and attn_metadata.num_decode_tokens is not None
)
num_mqa_tokens = attn_metadata.num_decode_tokens
num_mha_tokens = q.size(0) - num_mqa_tokens
if num_mha_tokens > 0:
self.impl.forward_mha(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
q[num_mqa_tokens:],
k_c_normed[num_mqa_tokens:],
k_pe[num_mqa_tokens:],
kv_cache,
attn_metadata,
self._k_scale,
output=output[num_decode_tokens:],
output=output[num_mqa_tokens:],
)
if has_decode or (has_prefill and is_sparse_impl):
# For sparse impl, we always use forward_mqa for all tokens
# For non-sparse impl, we only use forward_mqa for decode tokens
if is_sparse_impl:
mqa_q = q
mqa_output_slice = output
else:
assert attn_metadata.decode is not None
mqa_q = decode_q
mqa_output_slice = output[:num_decode_tokens]
if num_mqa_tokens > 0:
mqa_q = q[:num_mqa_tokens]
mqa_output_slice = output[:num_mqa_tokens]
mqa_q_nope, mqa_q_pe = mqa_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
@@ -644,6 +633,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)
# call decode attn
if not is_sparse_impl:
assert attn_metadata.decode is not None
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
# correct dcp attn_out with lse.