diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py
index 6a7369ad3..38cfdcdb3 100644
--- a/tests/v1/e2e/test_mamba_prefix_cache.py
+++ b/tests/v1/e2e/test_mamba_prefix_cache.py
@@ -76,14 +76,11 @@ def get_fake_sample_fn() -> SamplerOutput:
             ),
             logprobs_tensors=None,
         )
-        num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1
         accpeted_tokens = prompt_token_ids[
             first_token_id_index : first_token_id_index
             + min(num_accepted_tokens, logits.shape[0])
         ]
-        sampled_token_ids = accpeted_tokens + [-1] * (
-            num_sampled_tokens - len(accpeted_tokens)
-        )
+        sampled_token_ids = accpeted_tokens
         return SamplerOutput(
             sampled_token_ids=torch.tensor(
                 [sampled_token_ids], device="cuda", dtype=torch.int32
@@ -124,7 +121,24 @@ def get_fake_propose_draft_token_ids_fn():
                first_token_id_index : first_token_id_index + num_speculative_tokens
            ]
        ]
-        return proposed_draft_token_ids
+
+        next_token_ids = torch.tensor(
+            prompt_token_ids[
+                first_token_id_index - 1 : first_token_id_index
+                - 1
+                + num_accepted_tokens
+            ],
+            device="cuda",
+            dtype=torch.int32,
+        )
+
+        valid_sampled_tokens_count = torch.tensor(
+            [num_accepted_tokens], device="cuda", dtype=torch.int32
+        )
+
+        self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
+
+        return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32)
 
     return fake_propose_draft_token_ids_fn
 
@@ -184,6 +198,7 @@ mamba_kv_cache_dict = {}
 
 def get_fake_execute_model_fn(original_execute_model_fn: Callable):
     last_num_computed_tokens = 0
+    num_prompt_tokens = None
 
     def fake_execute_model_fn(
         self: GPUModelRunner,
@@ -201,10 +216,30 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
            mamba_group_id
        ].layer_names[0]
        nonlocal last_num_computed_tokens
+       nonlocal num_prompt_tokens
+
+       if (
+           len(scheduler_output.scheduled_new_reqs) > 0
+           and scheduler_output.scheduled_new_reqs[0].prompt_token_ids is not None
+       ):
+           # record number of prompt tokens
+           num_prompt_tokens = len(
+               scheduler_output.scheduled_new_reqs[0].prompt_token_ids
+           )
+
        if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0:
            num_computed_tokens = (
                scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
            )
+           if (
+               self.num_spec_tokens
+               and num_prompt_tokens is not None
+               and num_computed_tokens > num_prompt_tokens
+           ):
+               # NOTE (tdoublep) with async scheduling, the scheduler does not have an
+               # accurate measure of the number of computed tokens; we need to subtract
+               # the number of rejected tokens from the previous timestep.
+               num_computed_tokens -= num_speculative_tokens + 1 - num_accepted_tokens
            if (
                num_computed_tokens // BLOCK_SIZE
                > last_num_computed_tokens // BLOCK_SIZE
@@ -493,9 +528,9 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
             step_actions=[
                 StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                 StepAction(554, 4, [], (-1, -1), (-1, -1)),
-                StepAction(555, 4, [], (-1, -1), (-1, -1)),
+                StepAction(555, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
                 StepAction(556, 4, [], (-1, -1), (-1, -1)),
-                StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
+                StepAction(557, 4, [], (0, 1), (-1, -1)),
                 StepAction(558, 4, [], (-1, -1), (-1, -1)),
                 StepAction(559, 4, [], (-1, -1), (1, 0)),
                 StepAction(560, 4, [], (-1, -1), (-1, -1)),
@@ -510,8 +545,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
             step_actions=[
                 StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                 StepAction(554, 4, [], (-1, -1), (-1, -1)),
-                StepAction(556, 4, [], (-1, -1), (-1, -1)),
-                StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)),
+                StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
+                StepAction(558, 4, [], (1, 1), (2, 0)),
                 StepAction(560, 4, [], (-1, -1), (-1, -1)),
                 StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
             ],
@@ -526,7 +561,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
                 StepAction(555, 4, [], (-1, -1), (-1, -1)),
                 StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)),
                 StepAction(559, 4, [], (-1, -1), (1, 0)),
-                StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
+                StepAction(561, 4, [], (-1, -1), (-1, -1)),
+                StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
             ],
         ),
         "accept_3_1": TestConfig(
@@ -536,9 +572,10 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
             step_actions=[
                 StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                 StepAction(553, 4, [], (-1, -1), (-1, -1)),
-                StepAction(556, 4, [], (-1, -1), (-1, -1)),
-                StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)),
-                StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
+                StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
+                StepAction(559, 4, [], (2, 1), (1, 0)),
+                StepAction(562, 4, [], (-1, -1), (-1, -1)),
+                StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
             ],
         ),
         "accept_3_2": TestConfig(
@@ -561,7 +598,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
                 StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                 StepAction(555, 4, [], (-1, -1), (-1, -1)),
                 StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)),
-                StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
+                StepAction(561, 4, [], (-1, -1), (-1, -1)),
+                StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
             ],
         ),
         "accept_4_1": TestConfig(
@@ -572,8 +610,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
                 StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                 StepAction(553, 4, [], (-1, -1), (-1, -1)),
                 StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)),
-                StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
-                StepAction(565, 4, [], (-1, -1), (-1, -1)),
+                StepAction(561, 4, [], (-1, -1), (-1, -1)),
+                StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
             ],
         ),
         "accept_4_2": TestConfig(
@@ -584,8 +622,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
                 StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                 StepAction(554, 4, [], (-1, -1), (-1, -1)),
                 StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)),
-                StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
-                StepAction(566, 4, [], (-1, -1), (-1, -1)),
+                StepAction(562, 4, [], (-1, -1), (-1, -1)),
+                StepAction(566, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
             ],
         ),
         "accept_4_3": TestConfig(
@@ -596,7 +634,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
                 StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                 StepAction(555, 4, [], (-1, -1), (-1, -1)),
                 StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)),
-                StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
+                StepAction(563, 4, [], (-1, -1), (-1, -1)),
+                StepAction(567, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
             ],
         ),
         "accept_4_4": TestConfig(
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 095809d54..d6f1202e5 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -648,11 +648,6 @@ class VllmConfig:
                    "`external_launcher` distributed executor backend, but you chose "
                    f"`{executor_backend}`."
                )
-           if self.cache_config.mamba_cache_mode != "none":
-               raise ValueError(
-                   "Currently, async scheduling is not compatible with "
-                   "prefix caching for Mamba models."
-               )
        elif self.scheduler_config.async_scheduling is None:
            # Enable async scheduling unless there is an incompatible option.
            if (
@@ -685,13 +680,6 @@ class VllmConfig:
                    scope="local",
                )
                self.scheduler_config.async_scheduling = False
-           elif self.cache_config.mamba_cache_mode != "none":
-               logger.warning_once(
-                   "Async scheduling is not compatible with "
-                   "prefix caching for Mamba models and will be disabled.",
-                   scope="local",
-               )
-               self.scheduler_config.async_scheduling = False
            else:
                self.scheduler_config.async_scheduling = True
 
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 8e5edff2f..c071ae155 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -814,6 +814,14 @@ class MambaManager(SingleTypeKVCacheManager):
 
     def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         assert isinstance(self.kv_cache_spec, MambaSpec)
+
+        # NOTE (tdoublep) with async scheduling, the num_computed_tokens can contain
+        # draft tokens from the previous step that may or may not be rejected later.
+        # This can make us think we are further ahead in the sequence than we actually
+        # are, so let's assume that all tokens are rejected so we don't free blocks
+        # that we might actually need.
+        num_computed_tokens = max(0, num_computed_tokens - self.num_speculative_blocks)
+
         super().remove_skipped_blocks(request_id, num_computed_tokens)
         if self.mamba_cache_mode == "align":
             # `last_state_block_idx` refers to the block index allocated two steps ago.
@@ -879,6 +887,9 @@ class MambaManager(SingleTypeKVCacheManager):
            # We can ignore lookahead tokens because current draft models don't have
            # mamba layers.
            num_tokens = num_tokens_main_model
+
+       # NOTE(tdoublep): this is an over-estimate of how many blocks we need because
+       # num_tokens can include draft tokens that will later be rejected.
        num_required_blocks = (
            cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
        )
@@ -922,6 +933,8 @@ class MambaManager(SingleTypeKVCacheManager):
            # mamba layers.
            num_tokens = num_tokens_main_model
        req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
+       # NOTE(tdoublep): this is an over-estimate of how many blocks we need because
+       # num_tokens can include draft tokens that will later be rejected.
        num_required_blocks = (
            cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
        )
diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py
index 56fb02380..a22b0eeb0 100644
--- a/vllm/v1/worker/mamba_utils.py
+++ b/vllm/v1/worker/mamba_utils.py
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateCopyFunc,
 )
 from vllm.triton_utils import tl, triton
+from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
 from vllm.v1.worker.gpu_input_batch import CachedRequestState
@@ -142,7 +143,11 @@ def preprocess_mamba(
        # if num_computed_tokens is 0, prev_state_idx will be -1
        prev_state_idx = (req_state.num_computed_tokens - 1) // block_size
-       num_blocks = len(req_state.block_ids[mamba_group_ids[0]])
+       num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+       num_blocks: int = (
+           cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
+           + num_speculative_blocks
+       )
        # We always save the current running state at the last
        # (1 + num_speculative_blocks) block.
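
The arithmetic behind the async-scheduling correction in this patch can be illustrated in isolation. The sketch below is not part of the diff; the function name and example values are hypothetical. It only assumes what the changes above state: with speculative decoding, each step samples up to num_speculative_tokens + 1 tokens, async scheduling counts all of them as computed before verification, and only num_accepted_tokens of them survive.

# Standalone sketch (hypothetical names) mirroring the correction applied in
# fake_execute_model_fn above: subtract the tokens that were counted
# optimistically but later rejected from the scheduler's computed-token count.
def corrected_num_computed_tokens(
    scheduler_count: int, num_speculative_tokens: int, num_accepted_tokens: int
) -> int:
    num_rejected = (num_speculative_tokens + 1) - num_accepted_tokens
    return scheduler_count - num_rejected


# Example: with 4 draft tokens per step and 2 accepted tokens, a scheduler count
# of 562 corresponds to 562 - (4 + 1 - 2) = 559 tokens that were really computed.
assert corrected_num_computed_tokens(562, 4, 2) == 559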