[Hybrid] Enable mamba prefix cache "align" mode with async scheduling (#33997)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2026-02-14 22:15:56 +01:00
committed by GitHub
parent 73391a1baa
commit d5fe3f702c
4 changed files with 77 additions and 32 deletions

View File

@@ -76,14 +76,11 @@ def get_fake_sample_fn() -> SamplerOutput:
),
logprobs_tensors=None,
)
num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1
accpeted_tokens = prompt_token_ids[
first_token_id_index : first_token_id_index
+ min(num_accepted_tokens, logits.shape[0])
]
sampled_token_ids = accpeted_tokens + [-1] * (
num_sampled_tokens - len(accpeted_tokens)
)
sampled_token_ids = accpeted_tokens
return SamplerOutput(
sampled_token_ids=torch.tensor(
[sampled_token_ids], device="cuda", dtype=torch.int32
@@ -124,7 +121,24 @@ def get_fake_propose_draft_token_ids_fn():
first_token_id_index : first_token_id_index + num_speculative_tokens
]
]
return proposed_draft_token_ids
next_token_ids = torch.tensor(
prompt_token_ids[
first_token_id_index - 1 : first_token_id_index
- 1
+ num_accepted_tokens
],
device="cuda",
dtype=torch.int32,
)
valid_sampled_tokens_count = torch.tensor(
[num_accepted_tokens], device="cuda", dtype=torch.int32
)
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32)
return fake_propose_draft_token_ids_fn
@@ -184,6 +198,7 @@ mamba_kv_cache_dict = {}
def get_fake_execute_model_fn(original_execute_model_fn: Callable):
last_num_computed_tokens = 0
num_prompt_tokens = None
def fake_execute_model_fn(
self: GPUModelRunner,
@@ -201,10 +216,30 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
mamba_group_id
].layer_names[0]
nonlocal last_num_computed_tokens
nonlocal num_prompt_tokens
if (
len(scheduler_output.scheduled_new_reqs) > 0
and scheduler_output.scheduled_new_reqs[0].prompt_token_ids is not None
):
# record number of prompt tokens
num_prompt_tokens = len(
scheduler_output.scheduled_new_reqs[0].prompt_token_ids
)
if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0:
num_computed_tokens = (
scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
)
if (
self.num_spec_tokens
and num_prompt_tokens is not None
and num_computed_tokens > num_prompt_tokens
):
# NOTE (tdoublep) with async scheduling, the scheduler does not have an
# accurate measure of the number of computed tokens; we need to subtract
# the number of rejected tokens from the previous timestep.
num_computed_tokens -= num_speculative_tokens + 1 - num_accepted_tokens
if (
num_computed_tokens // BLOCK_SIZE
> last_num_computed_tokens // BLOCK_SIZE
@@ -493,9 +528,9 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions=[
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(555, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
StepAction(557, 4, [], (0, 1), (-1, -1)),
StepAction(558, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [], (-1, -1), (1, 0)),
StepAction(560, 4, [], (-1, -1), (-1, -1)),
@@ -510,8 +545,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions=[
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)),
StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(558, 4, [], (1, 1), (2, 0)),
StepAction(560, 4, [], (-1, -1), (-1, -1)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
@@ -526,7 +561,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)),
StepAction(559, 4, [], (-1, -1), (1, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(561, 4, [], (-1, -1), (-1, -1)),
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_3_1": TestConfig(
@@ -536,9 +572,10 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions=[
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(553, 4, [], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(559, 4, [], (2, 1), (1, 0)),
StepAction(562, 4, [], (-1, -1), (-1, -1)),
StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_3_2": TestConfig(
@@ -561,7 +598,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(561, 4, [], (-1, -1), (-1, -1)),
StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_1": TestConfig(
@@ -572,8 +610,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(553, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(565, 4, [], (-1, -1), (-1, -1)),
StepAction(561, 4, [], (-1, -1), (-1, -1)),
StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_2": TestConfig(
@@ -584,8 +622,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(566, 4, [], (-1, -1), (-1, -1)),
StepAction(562, 4, [], (-1, -1), (-1, -1)),
StepAction(566, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_3": TestConfig(
@@ -596,7 +634,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)),
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(563, 4, [], (-1, -1), (-1, -1)),
StepAction(567, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_4": TestConfig(