[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)

Signed-off-by: Shahar Mor <smor@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Roi Koren <roik@nvidia.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Benjamin Chislett
2026-02-24 12:49:56 -05:00
committed by GitHub
parent a9e15e040d
commit f5972a872f
19 changed files with 799 additions and 157 deletions

View File

@@ -265,7 +265,8 @@ class MambaMixer(MambaBase, PluggableLayer):
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba1AttentionMetadata)
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor = attn_metadata.state_indices_tensor
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
@@ -295,17 +296,13 @@ class MambaMixer(MambaBase, PluggableLayer):
prefill_decode_split = split_batch_to_prefill_and_decode(
hidden_states_BC,
gate,
state_indices_tensor,
num_prefill_tokens,
num_prefills,
num_decode_tokens,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
gate_p = prefill_decode_split.gate_p
gate_d = prefill_decode_split.gate_d
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
@@ -477,16 +474,12 @@ class PrefillDecodeSplit(NamedTuple):
hidden_states_BC_d: torch.Tensor
gate_p: torch.Tensor
gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor
def split_batch_to_prefill_and_decode(
hidden_states_BC: torch.Tensor,
gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
num_prefill_tokens: int,
num_prefills: int,
num_decode_tokens: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_decode_tokens
@@ -501,20 +494,11 @@ def split_batch_to_prefill_and_decode(
gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
)
# num_decode_tokens accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[: num_decode_tokens + num_prefills],
[num_decode_tokens, num_prefills],
dim=0,
)
return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,
hidden_states_BC_d=hidden_states_BC_d,
gate_p=gate_p,
gate_d=gate_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
)