[Bugfix] Fix sparse MLA metadata building (#33579)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -522,22 +522,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
  k_c_normed = k_c_normed[:num_actual_toks, ...]
  k_pe = k_pe[:num_actual_toks, ...]

- assert (
- attn_metadata.num_decodes is not None
- and attn_metadata.num_prefills is not None
- and attn_metadata.num_decode_tokens is not None
- )
-
- has_decode = attn_metadata.num_decodes > 0
- has_prefill = attn_metadata.num_prefills > 0
- num_decode_tokens = attn_metadata.num_decode_tokens
-
- decode_q = q[:num_decode_tokens]
-
- prefill_q = q[num_decode_tokens:]
- prefill_k_pe = k_pe[num_decode_tokens:]
- prefill_k_c_normed = k_c_normed[num_decode_tokens:]
-
  # write the latent and rope to kv cache
  if kv_cache.numel() > 0:
  ops.concat_and_cache_mla(
|
||||
@@ -555,27 +539,32 @@ class MLAAttention(nn.Module, AttentionLayerBase):
  # Sparse MLA impls only support forward_mqa (decode-style attention)
  is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)

- if has_prefill and not is_sparse_impl:
+ if is_sparse_impl:
+ num_mqa_tokens = q.size(0)
+ num_mha_tokens = 0
+ else:
+ assert (
+ attn_metadata.num_decodes is not None
+ and attn_metadata.num_prefills is not None
+ and attn_metadata.num_decode_tokens is not None
+ )
+ num_mqa_tokens = attn_metadata.num_decode_tokens
+ num_mha_tokens = q.size(0) - num_mqa_tokens
+
+ if num_mha_tokens > 0:
  self.impl.forward_mha(
- prefill_q,
- prefill_k_c_normed,
- prefill_k_pe,
+ q[num_mqa_tokens:],
+ k_c_normed[num_mqa_tokens:],
+ k_pe[num_mqa_tokens:],
  kv_cache,
  attn_metadata,
  self._k_scale,
- output=output[num_decode_tokens:],
+ output=output[num_mqa_tokens:],
  )

- if has_decode or (has_prefill and is_sparse_impl):
- # For sparse impl, we always use forward_mqa for all tokens
- # For non-sparse impl, we only use forward_mqa for decode tokens
- if is_sparse_impl:
- mqa_q = q
- mqa_output_slice = output
- else:
- assert attn_metadata.decode is not None
- mqa_q = decode_q
- mqa_output_slice = output[:num_decode_tokens]
+ if num_mqa_tokens > 0:
+ mqa_q = q[:num_mqa_tokens]
+ mqa_output_slice = output[:num_mqa_tokens]

  mqa_q_nope, mqa_q_pe = mqa_q.split(
  [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
@@ -644,6 +633,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
  mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)

  # call decode attn
+ if not is_sparse_impl:
+ assert attn_metadata.decode is not None
  attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

  # correct dcp attn_out with lse.
Reference in New Issue
Block a user