[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)

Signed-off-by: Shahar Mor <smor@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Roi Koren <roik@nvidia.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Benjamin Chislett
2026-02-24 12:49:56 -05:00
committed by GitHub
parent a9e15e040d
commit f5972a872f
19 changed files with 799 additions and 157 deletions

View File

@@ -477,7 +477,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
dim=-1,
)
compilation_config = get_current_vllm_config().compilation_config
vllm_config = get_current_vllm_config()
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
@@ -488,6 +489,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
self.cache_config = cache_config
self.prefix = prefix
self.num_spec = vllm_config.num_speculative_tokens
# Pre-compute sizes for forward pass
self.tped_intermediate_size = self.intermediate_size // self.tp_size
self.tped_conv_size = self.conv_dim // self.tp_size
@@ -576,7 +579,6 @@ class MambaMixer2(MambaBase, PluggableLayer):
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
@@ -584,6 +586,12 @@ class MambaMixer2(MambaBase, PluggableLayer):
query_start_loc_p = attn_metadata.query_start_loc_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
num_accepted_tokens = attn_metadata.num_accepted_tokens
query_start_loc_d = attn_metadata.query_start_loc_d
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
if attn_metadata is None:
# profile run
@@ -593,29 +601,21 @@ class MambaMixer2(MambaBase, PluggableLayer):
hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
return hidden_states
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
num_actual_tokens = num_prefill_tokens + num_decode_tokens
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
[num_decode_tokens, num_prefill_tokens],
dim=0,
)
dt_d, dt_p = torch.split(
dt[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[:num_actual_tokens],
[num_decodes, num_prefills],
[num_decode_tokens, num_prefill_tokens],
dim=0,
)
@@ -642,16 +642,16 @@ class MambaMixer2(MambaBase, PluggableLayer):
)
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
else:
block_idx_last_computed_token_d = None
block_idx_last_computed_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_scheduled_token_p = None
block_idx_first_scheduled_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_computed_token_d = None
num_computed_tokens_p = None
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
output[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
[num_decode_tokens, num_prefill_tokens],
dim=0,
)
@@ -709,6 +709,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
)
# NOTE: final output is an in-place update of out tensor
assert preallocated_ssm_out_p is not None
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(
num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
@@ -840,6 +841,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
conv_state_indices=state_indices_tensor_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
initial_state_idx=block_idx_last_computed_token_d,
num_accepted_tokens=num_accepted_tokens,
query_start_loc=query_start_loc_d,
max_query_len=state_indices_tensor_d.size(-1),
)
hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
@@ -862,6 +866,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
-1, self.num_heads // self.tp_size, self.head_dim
)
assert preallocated_ssm_out_d is not None
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d
@@ -879,7 +884,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
out=preallocated_ssm_out_d.view(num_decode_tokens, -1, self.head_dim),
num_accepted_tokens=num_accepted_tokens,
cu_seqlens=query_start_loc_d,
is_blackwell=self.is_blackwell,
)
@@ -901,6 +908,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
head_dim=self.head_dim,
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size,
num_spec=self.num_spec,
)
@property