[Bugfix] Fix NemotronH MTP + Chunked Prefill (#35447)

This commit is contained in:
Benjamin Chislett
2026-03-17 02:07:33 -04:00
committed by GitHub
parent 20b14095a4
commit 8a680463fa
5 changed files with 181 additions and 8 deletions

View File

@@ -414,8 +414,11 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
]
state_indices_tensor_p = state_indices_tensor_p[:, 0]
-if num_decodes > 0 and self.use_spec_decode:
-assert num_accepted_tokens is not None
+# Sometimes even with specdec enabled we get single-token prefill chunks that
+# should be treated as decodes but don't have num_accepted_tokens set.
+# These should be fine to process as non-spec decodes since there's only
+# one token, so no risk of placing accepted tokens in the wrong slot.
+if num_decodes > 0 and self.use_spec_decode and num_accepted_tokens is not None:
query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
num_accepted_tokens = num_accepted_tokens[:num_decodes]
@@ -501,9 +504,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
-if self.use_spec_decode:
+if self.use_spec_decode and num_accepted_tokens is not None:
assert query_start_loc_d is not None
-assert num_accepted_tokens is not None
query_start_loc_d = query_start_loc_d[: padded_bs + 1]
self.decode_num_accepted_tokens[: metadata.num_decodes].copy_(
num_accepted_tokens, non_blocking=True