[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
Asaf Joseph Gardin
2025-11-02 14:16:23 +02:00
committed by GitHub
parent 73444b7b56
commit 00b31a36a2
16 changed files with 442 additions and 153 deletions

View File

@@ -179,6 +179,10 @@ def selective_scan_opcheck_fn(
has_initial_state=None,
ssm_states=None,
pad_slot_id=PAD_SLOT_ID,
block_size=2048,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
@@ -223,6 +227,10 @@ def selective_scan_opcheck_fn(
has_initial_state,
ssm_states,
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
),
test_utils=["test_schema", "test_faketensor"],
)
@@ -338,6 +346,11 @@ def test_selective_scan(
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
if c > 0
else None,
pad_slot_id=PAD_SLOT_ID,
block_size=2048,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
)
outs.append(out)
if len(outs) > 1:
@@ -372,6 +385,7 @@ def test_selective_scan(
delta_bias=delta_bias,
delta_softplus=delta_softplus,
ssm_states=state,
block_size=2048,
)
@@ -586,6 +600,7 @@ def test_selective_scan_varlen(
padded_state_indices,
has_initial_state,
prev_state,
block_size=2048,
)