[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)
Signed-off-by: Shahar Mor <smor@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Roi Koren <roik@nvidia.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
commit f5972a872f
parent a9e15e040d
@@ -477,7 +477,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
             dim=-1,
         )
 
-        compilation_config = get_current_vllm_config().compilation_config
+        vllm_config = get_current_vllm_config()
+        compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
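Note on the hunk above: vLLM layers register themselves in the compilation config's `static_forward_context` under a unique prefix; the change only keeps the intermediate `vllm_config` around so later code can read speculative-decoding settings from it. A minimal standalone sketch of the registration guard, with `SimpleNamespace` standing in for the real config objects:

```python
# Minimal sketch of the registration pattern above; SimpleNamespace stands
# in for vLLM's real config objects.
from types import SimpleNamespace

compilation_config = SimpleNamespace(static_forward_context={})

def register_layer(prefix: str, layer: object) -> None:
    # Same guard as in the diff: each layer prefix may be registered once.
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = layer

register_layer("model.layers.0.mixer", object())
register_layer("model.layers.1.mixer", object())  # distinct prefix: ok
```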
@@ -488,6 +489,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
         self.cache_config = cache_config
         self.prefix = prefix
 
+        self.num_spec = vllm_config.num_speculative_tokens
+
         # Pre-compute sizes for forward pass
         self.tped_intermediate_size = self.intermediate_size // self.tp_size
         self.tped_conv_size = self.conv_dim // self.tp_size
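The new `num_spec` field is why decode token counts and decode request counts diverge in the forward pass below: with speculative decoding, each decode request can carry its bonus token plus up to `num_spec` draft tokens per step. A toy arithmetic sketch (illustrative numbers only, not vLLM API):

```python
# Illustrative only: with speculative decoding, each decode request can
# submit up to num_spec + 1 query tokens per step (bonus + draft tokens),
# so num_decode_tokens generally exceeds num_decodes.
num_spec = 3
num_decodes = 4                        # decode requests in the batch
max_tokens_per_decode = num_spec + 1   # bonus token + draft tokens
max_num_decode_tokens = num_decodes * max_tokens_per_decode
print(max_num_decode_tokens)  # 16
```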
@@ -576,7 +579,6 @@ class MambaMixer2(MambaBase, PluggableLayer):
             # conv_state = (..., dim, width-1) yet contiguous along 'dim'
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
-            state_indices_tensor = attn_metadata.state_indices_tensor
             has_initial_states_p = attn_metadata.has_initial_states_p
             prep_initial_states = attn_metadata.prep_initial_states
             chunk_size = attn_metadata.chunk_size
@@ -584,6 +586,12 @@ class MambaMixer2(MambaBase, PluggableLayer):
             query_start_loc_p = attn_metadata.query_start_loc_p
             cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
             last_chunk_indices_p = attn_metadata.last_chunk_indices_p
+            state_indices_tensor_p = attn_metadata.state_indices_tensor_p
+            state_indices_tensor_d = attn_metadata.state_indices_tensor_d
+            num_accepted_tokens = attn_metadata.num_accepted_tokens
+            query_start_loc_d = attn_metadata.query_start_loc_d
+            num_decodes = attn_metadata.num_decodes
+            num_decode_tokens = attn_metadata.num_decode_tokens
 
         if attn_metadata is None:
             # profile run
@@ -593,29 +601,21 @@ class MambaMixer2(MambaBase, PluggableLayer):
             hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
             return hidden_states
 
-        num_prefills = attn_metadata.num_prefills  # request count
-        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
-        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        num_prefills = attn_metadata.num_prefills
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
         has_prefill = num_prefills > 0
         has_decode = num_decodes > 0
-        num_actual_tokens = num_prefill_tokens + num_decodes
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens
 
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
         hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
             hidden_states_B_C[:num_actual_tokens],
-            [num_decodes, num_prefill_tokens],
+            [num_decode_tokens, num_prefill_tokens],
             dim=0,
         )
         dt_d, dt_p = torch.split(
             dt[:num_actual_tokens],
-            [num_decodes, num_prefill_tokens],
+            [num_decode_tokens, num_prefill_tokens],
             dim=0,
         )
-        # Split along batch dimension
-        state_indices_tensor_d, state_indices_tensor_p = torch.split(
-            state_indices_tensor[:num_actual_tokens],
-            [num_decodes, num_prefills],
-            dim=0,
-        )
 
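The splits above rely on the v1 engine's batch layout, where decode tokens are packed before prefill tokens along the token dimension; the split sizes now use `num_decode_tokens` rather than the request count, and the per-batch state-index split is gone because `state_indices_tensor_p`/`state_indices_tensor_d` now come directly from the attention metadata. A toy `torch.split` demonstration with made-up shapes:

```python
# Toy demonstration of the token-dimension split above: decode tokens are
# packed before prefill tokens, so one split along dim 0 separates them.
import torch

num_decode_tokens, num_prefill_tokens, hidden = 8, 24, 16
batch = torch.randn(num_decode_tokens + num_prefill_tokens, hidden)

decode_part, prefill_part = torch.split(
    batch, [num_decode_tokens, num_prefill_tokens], dim=0
)
assert decode_part.shape == (num_decode_tokens, hidden)
assert prefill_part.shape == (num_prefill_tokens, hidden)
```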
@@ -642,16 +642,16 @@ class MambaMixer2(MambaBase, PluggableLayer):
             )
             num_computed_tokens_p = attn_metadata.num_computed_tokens_p
         else:
-            block_idx_last_computed_token_d = None
-            block_idx_last_computed_token_p = None
-            block_idx_last_scheduled_token_d = None
-            block_idx_last_scheduled_token_p = None
+            block_idx_last_scheduled_token_p = None
+            block_idx_first_scheduled_token_p = None
+            block_idx_last_scheduled_token_d = None
+            block_idx_last_computed_token_d = None
             num_computed_tokens_p = None
 
         preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
             output[:num_actual_tokens],
-            [num_decodes, num_prefill_tokens],
+            [num_decode_tokens, num_prefill_tokens],
             dim=0,
         )
@@ -709,6 +709,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
             )
 
             # NOTE: final output is an in-place update of out tensor
+            assert preallocated_ssm_out_p is not None
            varlen_states = mamba_chunk_scan_combined_varlen(
                 hidden_states_p.view(
                     num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
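The `view` call above reshapes the flat prefill activations into per-head form before the chunked scan kernel. A shape-only sketch (dimensions invented for illustration):

```python
# Shape-only sketch of the .view(...) above: tokens stay in dim 0, and the
# hidden dimension is factored into (num_heads // tp_size, head_dim).
import torch

num_prefill_tokens, num_heads, tp_size, head_dim = 24, 8, 2, 16
flat = torch.randn(num_prefill_tokens, (num_heads // tp_size) * head_dim)
per_head = flat.view(num_prefill_tokens, num_heads // tp_size, head_dim)
assert per_head.shape == (24, 4, 16)
```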
@@ -840,6 +841,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
                 conv_state_indices=state_indices_tensor_d,
                 block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
                 initial_state_idx=block_idx_last_computed_token_d,
+                num_accepted_tokens=num_accepted_tokens,
+                query_start_loc=query_start_loc_d,
+                max_query_len=state_indices_tensor_d.size(-1),
             )
 
             hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
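The three new arguments thread speculative-decoding information into the conv-state update: how many draft tokens were actually accepted per request, the decode queries' start offsets, and a per-request query-length bound. The `max_query_len` derivation suggests `state_indices_tensor_d` carries one cache-slot index per potential decode token; a hypothetical shape sketch (names and shapes assumed for illustration, not taken from vLLM):

```python
# Hypothetical shape sketch (assumed): with speculative decoding, the
# decode state-index tensor plausibly holds one cache-slot index per
# potential decode token, shaped (num_decodes, num_spec + 1), so its last
# dimension bounds the per-request query length passed to the kernel.
import torch

num_decodes, num_spec = 4, 3
state_indices_tensor_d = torch.zeros(
    num_decodes, num_spec + 1, dtype=torch.int32
)
max_query_len = state_indices_tensor_d.size(-1)  # == num_spec + 1
assert max_query_len == 4
```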
@@ -862,6 +866,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
                 -1, self.num_heads // self.tp_size, self.head_dim
             )
 
+            assert preallocated_ssm_out_d is not None
             # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
             #   using state_indices_tensor_d
@@ -879,7 +884,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d_input,
                 dst_state_batch_indices=state_indices_tensor_d_output,
-                out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
+                out=preallocated_ssm_out_d.view(num_decode_tokens, -1, self.head_dim),
+                num_accepted_tokens=num_accepted_tokens,
+                cu_seqlens=query_start_loc_d,
+                is_blackwell=self.is_blackwell,
             )
 
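Reading state via `state_batch_indices` while writing to `dst_state_batch_indices` is what makes speculative decoding safe for a stateful SSM: the last verified state is never overwritten, so rejected draft tokens can be rolled back. A pure-torch toy of the source/destination indexing idea (not the fused kernel):

```python
# Pure-torch toy of src/dst state indexing (not the fused kernel): advance
# each request's state from a verified source slot and write the result to
# a separate destination slot, leaving the verified state intact.
import torch

state_pool = torch.zeros(8, 4)   # 8 cache slots, toy state dim 4
src = torch.tensor([0, 2])       # slots holding last verified states
dst = torch.tensor([1, 3])       # slots for this speculative step
update = torch.ones(2, 4)

state_pool[dst] = state_pool[src] + update  # src slots remain untouched
```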
@@ -901,6 +908,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
             head_dim=self.head_dim,
             state_size=self.ssm_state_size,
             conv_kernel=self.conv_kernel_size,
+            num_spec=self.num_spec,
         )
 
     @property
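Finally, `num_spec` is threaded into the state-shape computation, presumably so the per-request conv/SSM cache reserves room for the extra speculative positions. A back-of-the-envelope sketch of the idea (formula assumed for illustration, not taken from vLLM):

```python
# Assumed illustration (not vLLM's actual formula): a conv kernel of width
# K caches K - 1 past activations per request; serving num_spec extra
# speculative positions per step plausibly grows that cached window.
conv_kernel = 4
num_spec = 3
base_conv_window = conv_kernel - 1             # without speculation
spec_conv_window = conv_kernel - 1 + num_spec  # room for draft tokens
print(base_conv_window, spec_conv_window)      # 3 6
```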