[Hybrid] Enable mamba prefix cache "align" mode with async scheduling (#33997)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -76,14 +76,11 @@ def get_fake_sample_fn() -> SamplerOutput:
|
||||
),
|
||||
logprobs_tensors=None,
|
||||
)
|
||||
num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1
|
||||
accpeted_tokens = prompt_token_ids[
|
||||
first_token_id_index : first_token_id_index
|
||||
+ min(num_accepted_tokens, logits.shape[0])
|
||||
]
|
||||
sampled_token_ids = accpeted_tokens + [-1] * (
|
||||
num_sampled_tokens - len(accpeted_tokens)
|
||||
)
|
||||
sampled_token_ids = accpeted_tokens
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=torch.tensor(
|
||||
[sampled_token_ids], device="cuda", dtype=torch.int32
|
||||
@@ -124,7 +121,24 @@ def get_fake_propose_draft_token_ids_fn():
|
||||
first_token_id_index : first_token_id_index + num_speculative_tokens
|
||||
]
|
||||
]
|
||||
return proposed_draft_token_ids
|
||||
|
||||
next_token_ids = torch.tensor(
|
||||
prompt_token_ids[
|
||||
first_token_id_index - 1 : first_token_id_index
|
||||
- 1
|
||||
+ num_accepted_tokens
|
||||
],
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
valid_sampled_tokens_count = torch.tensor(
|
||||
[num_accepted_tokens], device="cuda", dtype=torch.int32
|
||||
)
|
||||
|
||||
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
|
||||
|
||||
return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32)
|
||||
|
||||
return fake_propose_draft_token_ids_fn
|
||||
|
||||
@@ -184,6 +198,7 @@ mamba_kv_cache_dict = {}
|
||||
|
||||
def get_fake_execute_model_fn(original_execute_model_fn: Callable):
|
||||
last_num_computed_tokens = 0
|
||||
num_prompt_tokens = None
|
||||
|
||||
def fake_execute_model_fn(
|
||||
self: GPUModelRunner,
|
||||
@@ -201,10 +216,30 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
|
||||
mamba_group_id
|
||||
].layer_names[0]
|
||||
nonlocal last_num_computed_tokens
|
||||
nonlocal num_prompt_tokens
|
||||
|
||||
if (
|
||||
len(scheduler_output.scheduled_new_reqs) > 0
|
||||
and scheduler_output.scheduled_new_reqs[0].prompt_token_ids is not None
|
||||
):
|
||||
# record number of prompt tokens
|
||||
num_prompt_tokens = len(
|
||||
scheduler_output.scheduled_new_reqs[0].prompt_token_ids
|
||||
)
|
||||
|
||||
if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0:
|
||||
num_computed_tokens = (
|
||||
scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
|
||||
)
|
||||
if (
|
||||
self.num_spec_tokens
|
||||
and num_prompt_tokens is not None
|
||||
and num_computed_tokens > num_prompt_tokens
|
||||
):
|
||||
# NOTE (tdoublep) with async scheduling, the scheduler does not have an
|
||||
# accurate measure of the number of computed tokens; we need to subtract
|
||||
# the number of reject tokens from the previous timestep.
|
||||
num_computed_tokens -= num_speculative_tokens + 1 - num_accepted_tokens
|
||||
if (
|
||||
num_computed_tokens // BLOCK_SIZE
|
||||
> last_num_computed_tokens // BLOCK_SIZE
|
||||
@@ -493,9 +528,9 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
step_actions=[
|
||||
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(554, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(555, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(555, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(556, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
|
||||
StepAction(557, 4, [], (0, 1), (-1, -1)),
|
||||
StepAction(558, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(559, 4, [], (-1, -1), (1, 0)),
|
||||
StepAction(560, 4, [], (-1, -1), (-1, -1)),
|
||||
@@ -510,8 +545,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
step_actions=[
|
||||
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(554, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(556, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)),
|
||||
StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(558, 4, [], (1, 1), (2, 0)),
|
||||
StepAction(560, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
@@ -526,7 +561,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(555, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)),
|
||||
StepAction(559, 4, [], (-1, -1), (1, 0)),
|
||||
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(561, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_3_1": TestConfig(
|
||||
@@ -536,9 +572,10 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
step_actions=[
|
||||
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(553, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(556, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)),
|
||||
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(559, 4, [], (2, 1), (1, 0)),
|
||||
StepAction(562, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_3_2": TestConfig(
|
||||
@@ -561,7 +598,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(555, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)),
|
||||
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(561, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_4_1": TestConfig(
|
||||
@@ -572,8 +610,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(553, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)),
|
||||
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(565, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(561, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_4_2": TestConfig(
|
||||
@@ -584,8 +622,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(554, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)),
|
||||
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(566, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(562, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(566, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_4_3": TestConfig(
|
||||
@@ -596,7 +634,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(555, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)),
|
||||
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(563, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(567, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_4_4": TestConfig(
|
||||
|
||||
@@ -648,11 +648,6 @@ class VllmConfig:
|
||||
"`external_launcher` distributed executor backend, but you chose "
|
||||
f"`{executor_backend}`."
|
||||
)
|
||||
if self.cache_config.mamba_cache_mode != "none":
|
||||
raise ValueError(
|
||||
"Currently, async scheduling is not compatible with "
|
||||
"prefix caching for Mamba models."
|
||||
)
|
||||
elif self.scheduler_config.async_scheduling is None:
|
||||
# Enable async scheduling unless there is an incompatible option.
|
||||
if (
|
||||
@@ -685,13 +680,6 @@ class VllmConfig:
|
||||
scope="local",
|
||||
)
|
||||
self.scheduler_config.async_scheduling = False
|
||||
elif self.cache_config.mamba_cache_mode != "none":
|
||||
logger.warning_once(
|
||||
"Async scheduling is not compatible with "
|
||||
"prefix caching for Mamba models and will be disabled.",
|
||||
scope="local",
|
||||
)
|
||||
self.scheduler_config.async_scheduling = False
|
||||
else:
|
||||
self.scheduler_config.async_scheduling = True
|
||||
|
||||
|
||||
@@ -814,6 +814,14 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
||||
assert isinstance(self.kv_cache_spec, MambaSpec)
|
||||
|
||||
# NOTE (tdoublep) with async scheduling, the num_computed_tokens can contain
|
||||
# draft tokens from the previous step that may or may not be rejected later.
|
||||
# This can make us think we are further ahead in the sequence than we actually
|
||||
# are, so let's assume that all tokens are rejected so we don't free blocks
|
||||
# that we might actually need.
|
||||
num_computed_tokens = max(0, num_computed_tokens - self.num_speculative_blocks)
|
||||
|
||||
super().remove_skipped_blocks(request_id, num_computed_tokens)
|
||||
if self.mamba_cache_mode == "align":
|
||||
# `last_state_block_idx` refers to the block index allocated two steps ago.
|
||||
@@ -879,6 +887,9 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
# We can ignore lookahead tokens because current draft models don't have
|
||||
# mamba layers.
|
||||
num_tokens = num_tokens_main_model
|
||||
|
||||
# NOTE(tdouble): this is an over-estimate of how many blocks we need because
|
||||
# num_tokens can include draft tokens that will later be rejected.
|
||||
num_required_blocks = (
|
||||
cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
|
||||
)
|
||||
@@ -922,6 +933,8 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
# mamba layers.
|
||||
num_tokens = num_tokens_main_model
|
||||
req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
|
||||
# NOTE(tdouble): this is an over-estimate of how many blocks we need because
|
||||
# num_tokens can include draft tokens that will later be rejected.
|
||||
num_required_blocks = (
|
||||
cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateCopyFunc,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
@@ -142,7 +143,11 @@ def preprocess_mamba(
|
||||
# if num_computed_tokens is 0, prev_state_idx will be -1
|
||||
prev_state_idx = (req_state.num_computed_tokens - 1) // block_size
|
||||
|
||||
num_blocks = len(req_state.block_ids[mamba_group_ids[0]])
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_blocks: int = (
|
||||
cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
|
||||
+ num_speculative_blocks
|
||||
)
|
||||
|
||||
# We always save the current running state at the last
|
||||
# (1 + num_speculative_blocks) block.
|
||||
|
||||
Reference in New Issue
Block a user