[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
commit 00b31a36a2 (parent 73444b7b56)
@@ -241,18 +241,21 @@ class MambaMixer(MambaBase, CustomOp):
         forward_context: ForwardContext = get_forward_context()
         attn_metadata = forward_context.attn_metadata
 
+        assert self.cache_config is not None
+        mamba_block_size = self.cache_config.mamba_block_size
+        prefix_caching_enabled = self.cache_config.enable_prefix_caching
+
         if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
-            mamba1_metadata = attn_metadata
-            assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
-            query_start_loc = mamba1_metadata.query_start_loc
-            state_indices_tensor = mamba1_metadata.state_indices_tensor
+            assert isinstance(attn_metadata, Mamba1AttentionMetadata)
+            query_start_loc_p = attn_metadata.query_start_loc_p
+            state_indices_tensor = attn_metadata.state_indices_tensor
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
-            has_initial_states = mamba1_metadata.has_initial_states
-            num_padded_decodes = mamba1_metadata.num_padded_decodes
+            has_initial_states_p = attn_metadata.has_initial_states_p
+            num_padded_decodes = attn_metadata.num_padded_decodes
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
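The `_p`-suffixed metadata fields (`query_start_loc_p`, `has_initial_states_p`) are prefill-only views that assume the batch is laid out decodes-first, prefills-after. As a minimal sketch of what `query_start_loc_p` holds, using the same formula this commit removes from `split_batch_to_prefill_and_decode` further down (values invented for illustration):

import torch

num_padded_decodes = 3  # three decode requests, one token each, at the front
num_prefills = 2        # two prefill requests with 4 and 6 tokens

# Combined cumulative token offsets over the whole batch.
query_start_loc = torch.tensor([0, 1, 2, 3, 7, 13], dtype=torch.int32)

# The last num_prefills + 1 entries, shifted by the decode token count, give
# cumulative offsets into the prefill-only token stream.
query_start_loc_p = query_start_loc[-num_prefills - 1 :] - num_padded_decodes
print(query_start_loc_p)  # tensor([ 0,  4, 10], dtype=torch.int32)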
@@ -279,12 +282,8 @@ class MambaMixer(MambaBase, CustomOp):
             hidden_states_BC,
             gate,
             state_indices_tensor,
-            query_start_loc,
-            has_initial_states,
             num_prefill_tokens,
-            num_decode_tokens,
             num_prefills,
-            num_decodes,
             num_padded_decodes,
         )
         hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
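The call site now passes only what the split still needs. Conceptually the split is a `torch.split` along the token dimension for activations and along the batch dimension for per-request tensors, under the decodes-first layout. A rough sketch; the dimension choices and shapes here are assumptions for illustration, not the exact source:

import torch

num_padded_decodes, num_prefill_tokens = 3, 10
num_prefills = 2

# Token-level tensor: decode tokens first, then prefill tokens (last dim).
hidden_states_BC = torch.randn(8, num_padded_decodes + num_prefill_tokens)
hidden_states_BC_d, hidden_states_BC_p = torch.split(
    hidden_states_BC, [num_padded_decodes, num_prefill_tokens], dim=-1
)

# Request-level tensor: one entry per request, decodes first (dim 0).
state_indices_tensor = torch.arange(num_padded_decodes + num_prefills)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
    state_indices_tensor, [num_padded_decodes, num_prefills], dim=0
)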
@@ -293,8 +292,34 @@ class MambaMixer(MambaBase, CustomOp):
         gate_d = prefill_decode_split.gate_d
         state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
         state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
-        query_start_loc_p = prefill_decode_split.query_start_loc_p
-        has_initial_states_p = prefill_decode_split.has_initial_states_p
 
+        if prefix_caching_enabled:
+            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_computed_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
+            )
+            block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_scheduled_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
+            )
+
+            block_idx_first_scheduled_token_p = (
+                attn_metadata.block_idx_first_scheduled_token_p
+            )
+            num_computed_tokens_p = attn_metadata.num_computed_tokens_p
+        else:
+            block_idx_last_computed_token_d = None
+            block_idx_last_computed_token_p = None
+            block_idx_last_scheduled_token_d = None
+            block_idx_last_scheduled_token_p = None
+            block_idx_first_scheduled_token_p = None
+            num_computed_tokens_p = None
 
         ssm_outputs = []
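The per-request block-index tensors arrive batched over all requests; because decodes precede prefills, a single `torch.split` on dim 0 separates the two groups. A small worked example with invented values:

import torch

num_decodes, num_prefills = 3, 2
block_idx_last_computed_token = torch.tensor([5, 0, 2, 7, 1], dtype=torch.int32)

block_idx_last_computed_token_d, block_idx_last_computed_token_p = torch.split(
    block_idx_last_computed_token, [num_decodes, num_prefills], dim=0
)
print(block_idx_last_computed_token_d)  # tensor([5, 0, 2], dtype=torch.int32)
print(block_idx_last_computed_token_p)  # tensor([7, 1], dtype=torch.int32)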
@@ -309,6 +334,11 @@ class MambaMixer(MambaBase, CustomOp):
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
                 query_start_loc=query_start_loc_p,
+                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
+                initial_state_idx=block_idx_last_computed_token_p,
+                num_computed_tokens=num_computed_tokens_p,
+                block_size_to_align=mamba_block_size,
             )
             # 3. State Space Model sequence transformations.
             discrete_time_step_p, B_p, C_p = self._ssm_transform(
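`block_size_to_align` reflects that APC can only reuse states snapshotted at block granularity: a prefill must leave a state copy at each `mamba_block_size` boundary it crosses so a later request sharing the prefix can resume from the last full block. A toy sketch of where those boundaries fall inside the newly scheduled tokens (an illustration of the alignment idea, not the kernel's actual indexing):

import torch

mamba_block_size = 4
num_computed_tokens = 6    # tokens of this request already in the cache
num_scheduled_tokens = 10  # tokens being processed this step

positions = num_computed_tokens + torch.arange(num_scheduled_tokens)
is_boundary = (positions + 1) % mamba_block_size == 0
print(positions[is_boundary])  # tensor([ 7, 11, 15]) -> last token of blocks 2, 3, 4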
@@ -331,10 +361,24 @@ class MambaMixer(MambaBase, CustomOp):
                 cache_indices=state_indices_tensor_p,
                 has_initial_state=has_initial_states_p,
                 query_start_loc=query_start_loc_p,
+                block_size=mamba_block_size,
+                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
+                initial_state_idx=block_idx_last_computed_token_p,
             )
             ssm_outputs.append(scan_out_p)
 
         if has_decode:
+            if prefix_caching_enabled:
+                state_indices_tensor_d_input = state_indices_tensor_d.gather(
+                    1, block_idx_last_computed_token_d.unsqueeze(1)
+                ).squeeze(1)
+                state_indices_tensor_d_output = state_indices_tensor_d.gather(
+                    1, block_idx_last_scheduled_token_d.unsqueeze(1)
+                ).squeeze(1)
+            else:
+                state_indices_tensor_d_input = state_indices_tensor_d
+                state_indices_tensor_d_output = state_indices_tensor_d
             # 2. Convolution sequence transformation
             conv_out_d = causal_conv1d_update(
                 hidden_states_BC_d.transpose(0, 1),
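For decodes under APC, `state_indices_tensor_d` is effectively a per-request block table (one row per request), and the new per-request column indices pick which block to read the running state from (`block_idx_last_computed_token_d`) and which block to write it to (`block_idx_last_scheduled_token_d`). The gather pattern, on invented values:

import torch

state_indices_tensor_d = torch.tensor([[10, 11, 12],
                                       [20, 21, 22]])
block_idx_last_computed_token_d = torch.tensor([1, 0])
block_idx_last_scheduled_token_d = torch.tensor([1, 1])

read_slots = state_indices_tensor_d.gather(
    1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
write_slots = state_indices_tensor_d.gather(
    1, block_idx_last_scheduled_token_d.unsqueeze(1)
).squeeze(1)
print(read_slots)   # tensor([11, 20])
print(write_slots)  # tensor([11, 21])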
@@ -343,6 +387,8 @@ class MambaMixer(MambaBase, CustomOp):
                 self.conv1d.bias,
                 self.activation,
                 conv_state_indices=state_indices_tensor_d,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
+                initial_state_idx=block_idx_last_computed_token_d,
             ).transpose(0, 1)
 
             # 3. State Space Model sequence transformation.
@@ -364,7 +410,8 @@ class MambaMixer(MambaBase, CustomOp):
                 gate_d.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-                state_batch_indices=state_indices_tensor_d,
+                state_batch_indices=state_indices_tensor_d_input,
+                dst_state_batch_indices=state_indices_tensor_d_output,
                 out=scan_outputs_d,
             )
             scan_outputs_d = scan_outputs_d.transpose(0, 1)
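`selective_state_update` previously read and wrote the decode state in place through a single `state_batch_indices`. The `dst_state_batch_indices` addition lets a decode step read the state from one cache slot and write the updated state to a different slot, which matters when the step crosses a block boundary and the prior block must stay intact for prefix reuse. A toy model of that read/write split in plain tensor ops (not the fused kernel):

import torch

ssm_state = torch.zeros(4, 2, 2)  # four cache slots holding 2x2 states
ssm_state[1] = torch.eye(2)       # pretend slot 1 holds the current state

src = torch.tensor([1])           # state_batch_indices: slot to read
dst = torch.tensor([2])           # dst_state_batch_indices: slot to write

updated = ssm_state[src] + 1.0    # stand-in for the real SSM recurrence
ssm_state[dst] = updated          # slot 1 is left untouched for the prefix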
@@ -423,20 +470,14 @@ class PrefillDecodeSplit(NamedTuple):
     gate_d: torch.Tensor
     state_indices_tensor_p: torch.Tensor
     state_indices_tensor_d: torch.Tensor
-    query_start_loc_p: torch.Tensor | None
-    has_initial_states_p: torch.Tensor | None
 
 
 def split_batch_to_prefill_and_decode(
     hidden_states_BC: torch.Tensor,
     gate: torch.Tensor,
     state_indices_tensor: torch.Tensor,
-    query_start_loc: torch.Tensor,
-    has_initial_states: torch.Tensor | None,
     num_prefill_tokens: int,
-    num_decode_tokens: int,
     num_prefills: int,
-    num_decodes: int,
     num_padded_decodes: int,
 ) -> PrefillDecodeSplit:
     num_actual_tokens = num_prefill_tokens + num_padded_decodes
@@ -457,16 +498,6 @@ def split_batch_to_prefill_and_decode(
         [num_padded_decodes, num_prefills],
         dim=0,
     )
-    query_start_loc_p = (
-        query_start_loc[-num_prefills - 1 :] - num_padded_decodes
-        if num_prefills > 0
-        else None
-    )
-    has_initial_states_p = (
-        has_initial_states[-num_prefills:]
-        if (has_initial_states is not None and num_prefills > 0)
-        else None
-    )
 
     return PrefillDecodeSplit(
         hidden_states_BC_p=hidden_states_BC_p,
@@ -475,8 +506,6 @@ def split_batch_to_prefill_and_decode(
         gate_d=gate_d,
         state_indices_tensor_p=state_indices_tensor_p,
         state_indices_tensor_d=state_indices_tensor_d,
-        query_start_loc_p=query_start_loc_p,
-        has_initial_states_p=has_initial_states_p,
     )
@@ -375,6 +375,10 @@ def selective_scan_fn(
     cache_indices=None,
     has_initial_state=None,
     pad_slot_id=PAD_SLOT_ID,
+    block_size=1024,
+    block_idx_first_scheduled_token=None,
+    block_idx_last_scheduled_token=None,
+    initial_state_idx=None,
 ) -> torch.Tensor:
     """
     u: (dim, total_length) for varlen or (batch, dim, seqlen)
@@ -397,7 +401,10 @@ def selective_scan_fn(
            x.shape=(dim,17)
     cache_indices: (batch) int32
         A tensor in which each cell is the corresponding
-        input and output ssm_state index
+        input and output ssm_state indices
+        - Without APC: (batch,) - a single state index per batch item
+        - With APC: (batch, max_positions) - cache block indices for read/write
+          Each non-zero value indicates a cache block to load from and/or write to.
     has_initial_state: (batch) bool
         A tensor populated with ones and zeros,
         indicating whether the ssm_state at the corresponding index should be
@@ -408,6 +415,17 @@ def selective_scan_fn(
         that will not be processed,
         for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
         in this case, the kernel will not process entries at indices 0 and 3
+    block_size: int
+        The block size to align the cached states to
+    block_idx_first_scheduled_token: (batch,), dtype int32
+        The pointer into cache_indices where the first
+        cache block to be filled is located.
+    block_idx_last_scheduled_token: (batch,), dtype int32
+        The pointer into cache_indices where the last cache block
+        to be filled is located.
+    initial_state_idx: (batch,), dtype int32
+        The pointer into cache_indices where the cache block
+        containing the initial state is located.
     returns
     output: (dim, total_length) for varlen or (batch, dim, seqlen)
         supports inplace replacement
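Putting the new docstring together, a hypothetical single-request layout (all values invented): `cache_indices` holds the request's block table, and the three pointer tensors are column indices into it:

import torch

# Block table for one request, padded at the end.
cache_indices = torch.tensor([[10, 11, 12, 13, 0, 0]], dtype=torch.int32)

initial_state_idx = torch.tensor([1])                # resume from cache block 11
block_idx_first_scheduled_token = torch.tensor([2])  # first block to fill: 12
block_idx_last_scheduled_token = torch.tensor([3])   # last block to fill: 13

b = 0
print(cache_indices[b, initial_state_idx[b]])                # tensor(11, dtype=torch.int32)
print(cache_indices[b, block_idx_first_scheduled_token[b]])  # tensor(12, dtype=torch.int32)
print(cache_indices[b, block_idx_last_scheduled_token[b]])   # tensor(13, dtype=torch.int32)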
@@ -448,6 +466,10 @@ def selective_scan_fn(
         has_initial_state,
         ssm_states,
         pad_slot_id,
+        block_size,
+        block_idx_first_scheduled_token,
+        block_idx_last_scheduled_token,
+        initial_state_idx,
     )
 
     if z is None: