[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
73444b7b56
commit
00b31a36a2
@@ -179,6 +179,10 @@ def selective_scan_opcheck_fn(
|
||||
has_initial_state=None,
|
||||
ssm_states=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
block_size=2048,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
initial_state_idx=None,
|
||||
):
|
||||
"""if return_last_state is True, returns (out, last_state)
|
||||
last_state has shape (batch, dim, dstate).
|
||||
@@ -223,6 +227,10 @@ def selective_scan_opcheck_fn(
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx,
|
||||
),
|
||||
test_utils=["test_schema", "test_faketensor"],
|
||||
)
|
||||
@@ -338,6 +346,11 @@ def test_selective_scan(
|
||||
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
|
||||
if c > 0
|
||||
else None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
block_size=2048,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
initial_state_idx=None,
|
||||
)
|
||||
outs.append(out)
|
||||
if len(outs) > 1:
|
||||
@@ -372,6 +385,7 @@ def test_selective_scan(
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
ssm_states=state,
|
||||
block_size=2048,
|
||||
)
|
||||
|
||||
|
||||
@@ -586,6 +600,7 @@ def test_selective_scan_varlen(
|
||||
padded_state_indices,
|
||||
has_initial_state,
|
||||
prev_state,
|
||||
block_size=2048,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user