[Hybrid] Enable mamba prefix cache "align" mode with async scheduling (#33997)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2026-02-14 22:15:56 +01:00
committed by GitHub
parent 73391a1baa
commit d5fe3f702c
4 changed files with 77 additions and 32 deletions

View File

@@ -76,14 +76,11 @@ def get_fake_sample_fn() -> SamplerOutput:
),
logprobs_tensors=None,
)
num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1
accpeted_tokens = prompt_token_ids[
first_token_id_index : first_token_id_index
+ min(num_accepted_tokens, logits.shape[0])
]
sampled_token_ids = accpeted_tokens + [-1] * (
num_sampled_tokens - len(accpeted_tokens)
)
sampled_token_ids = accpeted_tokens
return SamplerOutput(
sampled_token_ids=torch.tensor(
[sampled_token_ids], device="cuda", dtype=torch.int32
@@ -124,7 +121,24 @@ def get_fake_propose_draft_token_ids_fn():
first_token_id_index : first_token_id_index + num_speculative_tokens
]
]
return proposed_draft_token_ids
next_token_ids = torch.tensor(
prompt_token_ids[
first_token_id_index - 1 : first_token_id_index
- 1
+ num_accepted_tokens
],
device="cuda",
dtype=torch.int32,
)
valid_sampled_tokens_count = torch.tensor(
[num_accepted_tokens], device="cuda", dtype=torch.int32
)
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32)
return fake_propose_draft_token_ids_fn
@@ -184,6 +198,7 @@ mamba_kv_cache_dict = {}
def get_fake_execute_model_fn(original_execute_model_fn: Callable):
last_num_computed_tokens = 0
num_prompt_tokens = None
def fake_execute_model_fn(
self: GPUModelRunner,
@@ -201,10 +216,30 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
mamba_group_id
].layer_names[0]
nonlocal last_num_computed_tokens
nonlocal num_prompt_tokens
if (
len(scheduler_output.scheduled_new_reqs) > 0
and scheduler_output.scheduled_new_reqs[0].prompt_token_ids is not None
):
# record number of prompt tokens
num_prompt_tokens = len(
scheduler_output.scheduled_new_reqs[0].prompt_token_ids
)
if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0:
num_computed_tokens = (
scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
)
if (
self.num_spec_tokens
and num_prompt_tokens is not None
and num_computed_tokens > num_prompt_tokens
):
# NOTE (tdoublep) with async scheduling, the scheduler does not have an
# accurate measure of the number of computed tokens; we need to subtract
# the number of reject tokens from the previous timestep.
num_computed_tokens -= num_speculative_tokens + 1 - num_accepted_tokens
if (
num_computed_tokens // BLOCK_SIZE
> last_num_computed_tokens // BLOCK_SIZE
@@ -493,9 +528,9 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions=[
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(555, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
StepAction(557, 4, [], (0, 1), (-1, -1)),
StepAction(558, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [], (-1, -1), (1, 0)),
StepAction(560, 4, [], (-1, -1), (-1, -1)),
@@ -510,8 +545,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions=[
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)),
StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(558, 4, [], (1, 1), (2, 0)),
StepAction(560, 4, [], (-1, -1), (-1, -1)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
@@ -526,7 +561,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)),
StepAction(559, 4, [], (-1, -1), (1, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(561, 4, [], (-1, -1), (-1, -1)),
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_3_1": TestConfig(
@@ -536,9 +572,10 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
step_actions=[
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(553, 4, [], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(559, 4, [], (2, 1), (1, 0)),
StepAction(562, 4, [], (-1, -1), (-1, -1)),
StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_3_2": TestConfig(
@@ -561,7 +598,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(561, 4, [], (-1, -1), (-1, -1)),
StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_1": TestConfig(
@@ -572,8 +610,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(553, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(565, 4, [], (-1, -1), (-1, -1)),
StepAction(561, 4, [], (-1, -1), (-1, -1)),
StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_2": TestConfig(
@@ -584,8 +622,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(566, 4, [], (-1, -1), (-1, -1)),
StepAction(562, 4, [], (-1, -1), (-1, -1)),
StepAction(566, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_3": TestConfig(
@@ -596,7 +634,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)),
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(563, 4, [], (-1, -1), (-1, -1)),
StepAction(567, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_4": TestConfig(

View File

@@ -648,11 +648,6 @@ class VllmConfig:
"`external_launcher` distributed executor backend, but you chose "
f"`{executor_backend}`."
)
if self.cache_config.mamba_cache_mode != "none":
raise ValueError(
"Currently, async scheduling is not compatible with "
"prefix caching for Mamba models."
)
elif self.scheduler_config.async_scheduling is None:
# Enable async scheduling unless there is an incompatible option.
if (
@@ -685,13 +680,6 @@ class VllmConfig:
scope="local",
)
self.scheduler_config.async_scheduling = False
elif self.cache_config.mamba_cache_mode != "none":
logger.warning_once(
"Async scheduling is not compatible with "
"prefix caching for Mamba models and will be disabled.",
scope="local",
)
self.scheduler_config.async_scheduling = False
else:
self.scheduler_config.async_scheduling = True

View File

@@ -814,6 +814,14 @@ class MambaManager(SingleTypeKVCacheManager):
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
assert isinstance(self.kv_cache_spec, MambaSpec)
# NOTE (tdoublep) with async scheduling, the num_computed_tokens can contain
# draft tokens from the previous step that may or may not be rejected later.
# This can make us think we are further ahead in the sequence than we actually
# are, so let's assume that all tokens are rejected so we don't free blocks
# that we might actually need.
num_computed_tokens = max(0, num_computed_tokens - self.num_speculative_blocks)
super().remove_skipped_blocks(request_id, num_computed_tokens)
if self.mamba_cache_mode == "align":
# `last_state_block_idx` refers to the block index allocated two steps ago.
@@ -879,6 +887,9 @@ class MambaManager(SingleTypeKVCacheManager):
# We can ignore lookahead tokens because current draft models don't have
# mamba layers.
num_tokens = num_tokens_main_model
# NOTE(tdouble): this is an over-estimate of how many blocks we need because
# num_tokens can include draft tokens that will later be rejected.
num_required_blocks = (
cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
)
@@ -922,6 +933,8 @@ class MambaManager(SingleTypeKVCacheManager):
# mamba layers.
num_tokens = num_tokens_main_model
req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
# NOTE(tdouble): this is an over-estimate of how many blocks we need because
# num_tokens can include draft tokens that will later be rejected.
num_required_blocks = (
cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
)

View File

@@ -10,6 +10,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
)
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.worker.gpu_input_batch import CachedRequestState
@@ -142,7 +143,11 @@ def preprocess_mamba(
# if num_computed_tokens is 0, prev_state_idx will be -1
prev_state_idx = (req_state.num_computed_tokens - 1) // block_size
num_blocks = len(req_state.block_ids[mamba_group_ids[0]])
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_blocks: int = (
cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
+ num_speculative_blocks
)
# We always save the current running state at the last
# (1 + num_speculative_blocks) block.