[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)
Signed-off-by: Shahar Mor <smor@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Roi Koren <roik@nvidia.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
commit f5972a872f
parent a9e15e040d
```
@@ -265,7 +265,8 @@ class MambaMixer(MambaBase, PluggableLayer):
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, Mamba1AttentionMetadata)
        query_start_loc_p = attn_metadata.query_start_loc_p
        state_indices_tensor = attn_metadata.state_indices_tensor
        state_indices_tensor_p = attn_metadata.state_indices_tensor_p
        state_indices_tensor_d = attn_metadata.state_indices_tensor_d
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
```
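For context, each `state_indices_tensor*` holds per-request slot indices into the shared Mamba state cache, so the layer can gather the right conv/SSM state for every request in the batch. A minimal sketch of that gather pattern, with invented shapes and slot numbers (not vLLM's actual cache layout):

```python
import torch

# Hypothetical state pool: one slot per running request.
num_slots, d_state, d_inner = 8, 16, 32
ssm_state = torch.zeros(num_slots, d_state, d_inner)

# Slot indices for the requests in the current batch.
state_indices_tensor = torch.tensor([3, 5, 0])

# Gather each request's recurrent state for this forward pass.
batch_states = ssm_state[state_indices_tensor]  # shape (3, 16, 32)
```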
```
@@ -295,17 +296,13 @@ class MambaMixer(MambaBase, PluggableLayer):
        prefill_decode_split = split_batch_to_prefill_and_decode(
            hidden_states_BC,
            gate,
            state_indices_tensor,
            num_prefill_tokens,
            num_prefills,
            num_decode_tokens,
        )
        hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
        hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
        gate_p = prefill_decode_split.gate_p
        gate_d = prefill_decode_split.gate_d
        state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
        state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d

        if is_mamba_cache_all:
            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
```
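The counts passed to `split_batch_to_prefill_and_decode` follow from the batch layout, where decode tokens are ordered before prefill tokens. A toy illustration of how such counts relate, with invented sequence lengths (speculative decoding can make decode entries longer than one token):

```python
# Invented batch: two decode requests followed by two prefill requests.
decode_lens = [1, 1]    # one token per decode request in the simple case
prefill_lens = [5, 3]   # full prompt lengths for prefill requests

num_decode_tokens = sum(decode_lens)    # 2
num_prefills = len(prefill_lens)        # 2
num_prefill_tokens = sum(prefill_lens)  # 8
num_actual_tokens = num_decode_tokens + num_prefill_tokens  # 10
```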
```
@@ -477,16 +474,12 @@ class PrefillDecodeSplit(NamedTuple):
    hidden_states_BC_d: torch.Tensor
    gate_p: torch.Tensor
    gate_d: torch.Tensor
    state_indices_tensor_p: torch.Tensor
    state_indices_tensor_d: torch.Tensor


def split_batch_to_prefill_and_decode(
    hidden_states_BC: torch.Tensor,
    gate: torch.Tensor,
    state_indices_tensor: torch.Tensor,
    num_prefill_tokens: int,
    num_prefills: int,
    num_decode_tokens: int,
) -> PrefillDecodeSplit:
    num_actual_tokens = num_prefill_tokens + num_decode_tokens
```
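Returning the six split tensors through a `NamedTuple` keeps the call site readable and lets callers unpack by name or position. A standalone miniature of the pattern, with the fields trimmed to two for brevity:

```python
from typing import NamedTuple

import torch


class MiniSplit(NamedTuple):
    hidden_states_BC_p: torch.Tensor
    hidden_states_BC_d: torch.Tensor


split = MiniSplit(
    hidden_states_BC_p=torch.ones(2, 4),
    hidden_states_BC_d=torch.zeros(2, 2),
)
# Access by field name, or unpack positionally like a plain tuple.
hidden_p, hidden_d = split
assert hidden_p is split.hidden_states_BC_p
```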
```
@@ -501,20 +494,11 @@ def split_batch_to_prefill_and_decode(
        gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
    )

    # num_decode_tokens accounts for CUDA graph padding when applicable
    state_indices_tensor_d, state_indices_tensor_p = torch.split(
        state_indices_tensor[: num_decode_tokens + num_prefills],
        [num_decode_tokens, num_prefills],
        dim=0,
    )

    return PrefillDecodeSplit(
        hidden_states_BC_p=hidden_states_BC_p,
        hidden_states_BC_d=hidden_states_BC_d,
        gate_p=gate_p,
        gate_d=gate_d,
        state_indices_tensor_p=state_indices_tensor_p,
        state_indices_tensor_d=state_indices_tensor_d,
    )
```
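To make the split semantics concrete: token-level tensors (`hidden_states_BC`, `gate`) are sliced to `num_actual_tokens` and split along the token dimension, while the request-level `state_indices_tensor` is split into `num_decode_tokens` decode slots (padding included) followed by `num_prefills` prefill slots. A self-contained sketch with invented sizes:

```python
import torch

num_decode_tokens, num_prefills, num_prefill_tokens = 2, 1, 4
num_actual_tokens = num_prefill_tokens + num_decode_tokens  # 6

# Token-level tensor, tokens on the last dim, decode tokens first.
hidden_states_BC = torch.arange(16.0).reshape(2, 8)  # 2 channels, 8 tokens
hidden_states_BC_d, hidden_states_BC_p = torch.split(
    hidden_states_BC[..., :num_actual_tokens],
    [num_decode_tokens, num_prefill_tokens],
    dim=-1,
)

# Request-level tensor: one slot per decode token, one per prefill request.
state_indices_tensor = torch.tensor([7, 2, 4])
state_indices_tensor_d, state_indices_tensor_p = torch.split(
    state_indices_tensor[: num_decode_tokens + num_prefills],
    [num_decode_tokens, num_prefills],
    dim=0,
)

print(hidden_states_BC_d.shape, hidden_states_BC_p.shape)  # (2, 2) (2, 4)
print(state_indices_tensor_d.tolist(), state_indices_tensor_p.tolist())  # [7, 2] [4]
```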