[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
Asaf Joseph Gardin
2025-11-02 14:16:23 +02:00
committed by GitHub
parent 73444b7b56
commit 00b31a36a2
16 changed files with 442 additions and 153 deletions

View File

@@ -241,18 +241,21 @@ class MambaMixer(MambaBase, CustomOp):
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
mamba1_metadata = attn_metadata
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
query_start_loc = mamba1_metadata.query_start_loc
state_indices_tensor = mamba1_metadata.state_indices_tensor
assert isinstance(attn_metadata, Mamba1AttentionMetadata)
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor = attn_metadata.state_indices_tensor
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states = mamba1_metadata.has_initial_states
num_padded_decodes = mamba1_metadata.num_padded_decodes
has_initial_states_p = attn_metadata.has_initial_states_p
num_padded_decodes = attn_metadata.num_padded_decodes
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -279,12 +282,8 @@ class MambaMixer(MambaBase, CustomOp):
hidden_states_BC,
gate,
state_indices_tensor,
query_start_loc,
has_initial_states,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
num_decodes,
num_padded_decodes,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
@@ -293,8 +292,34 @@ class MambaMixer(MambaBase, CustomOp):
gate_d = prefill_decode_split.gate_d
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
query_start_loc_p = prefill_decode_split.query_start_loc_p
has_initial_states_p = prefill_decode_split.has_initial_states_p
if prefix_caching_enabled:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
[num_decodes, num_prefills],
dim=0,
)
)
block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
torch.split(
attn_metadata.block_idx_last_scheduled_token,
[num_decodes, num_prefills],
dim=0,
)
)
block_idx_first_scheduled_token_p = (
attn_metadata.block_idx_first_scheduled_token_p
)
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
else:
block_idx_last_computed_token_d = None
block_idx_last_computed_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_scheduled_token_p = None
block_idx_first_scheduled_token_p = None
num_computed_tokens_p = None
ssm_outputs = []
@@ -309,6 +334,11 @@ class MambaMixer(MambaBase, CustomOp):
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p,
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
num_computed_tokens=num_computed_tokens_p,
block_size_to_align=mamba_block_size,
)
# 3. State Space Model sequence transformations.
discrete_time_step_p, B_p, C_p = self._ssm_transform(
@@ -331,10 +361,24 @@ class MambaMixer(MambaBase, CustomOp):
cache_indices=state_indices_tensor_p,
has_initial_state=has_initial_states_p,
query_start_loc=query_start_loc_p,
block_size=mamba_block_size,
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
)
ssm_outputs.append(scan_out_p)
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
state_indices_tensor_d_output = state_indices_tensor_d.gather(
1, block_idx_last_scheduled_token_d.unsqueeze(1)
).squeeze(1)
else:
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d
# 2. Convolution sequence transformation
conv_out_d = causal_conv1d_update(
hidden_states_BC_d.transpose(0, 1),
@@ -343,6 +387,8 @@ class MambaMixer(MambaBase, CustomOp):
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
initial_state_idx=block_idx_last_computed_token_d,
).transpose(0, 1)
# 3. State Space Model sequence transformation.
@@ -364,7 +410,8 @@ class MambaMixer(MambaBase, CustomOp):
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=scan_outputs_d,
)
scan_outputs_d = scan_outputs_d.transpose(0, 1)
@@ -423,20 +470,14 @@ class PrefillDecodeSplit(NamedTuple):
gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor
query_start_loc_p: torch.Tensor | None
has_initial_states_p: torch.Tensor | None
def split_batch_to_prefill_and_decode(
hidden_states_BC: torch.Tensor,
gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
query_start_loc: torch.Tensor,
has_initial_states: torch.Tensor | None,
num_prefill_tokens: int,
num_decode_tokens: int,
num_prefills: int,
num_decodes: int,
num_padded_decodes: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes
@@ -457,16 +498,6 @@ def split_batch_to_prefill_and_decode(
[num_padded_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
query_start_loc[-num_prefills - 1 :] - num_padded_decodes
if num_prefills > 0
else None
)
has_initial_states_p = (
has_initial_states[-num_prefills:]
if (has_initial_states is not None and num_prefills > 0)
else None
)
return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,
@@ -475,8 +506,6 @@ def split_batch_to_prefill_and_decode(
gate_d=gate_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
)