[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)

Signed-off-by: Shahar Mor <smor@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Roi Koren <roik@nvidia.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Benjamin Chislett
2026-02-24 12:49:56 -05:00
committed by GitHub
parent a9e15e040d
commit f5972a872f
19 changed files with 799 additions and 157 deletions

View File

@@ -119,7 +119,8 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
has_initial_states_p = attn_metadata.has_initial_states_p
query_start_loc_p = attn_metadata.query_start_loc_p
@@ -163,13 +164,6 @@ class ShortConv(MambaBase, CustomOp):
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
conv_output_list = []
if has_prefill: