diff --git a/tests/v1/worker/test_mamba_utils.py b/tests/v1/worker/test_mamba_utils.py
new file mode 100644
index 000000000..38eb250fb
--- /dev/null
+++ b/tests/v1/worker/test_mamba_utils.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import MagicMock, patch
+
+from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
+from vllm.v1.worker.mamba_utils import preprocess_mamba
+
+
+def _make_scheduler_output(  # SchedulerOutput fixture carrying three id sets
+    finished_req_ids: set[str],
+    preempted_req_ids: set[str] | None,
+    resumed_req_ids: set[str],
+) -> SchedulerOutput:
+    cached = CachedRequestData.make_empty()
+    cached.resumed_req_ids = resumed_req_ids  # resumed ids live on the cached reqs
+    return SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=cached,
+        num_scheduled_tokens={},
+        total_num_scheduled_tokens=0,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=[],
+        finished_req_ids=finished_req_ids,
+        free_encoder_mm_hashes=[],
+        preempted_req_ids=preempted_req_ids,
+    )
+
+
+def test_resumed_req_ids_cleared_from_mamba_state_idx():
+    """When a request is force-preempted (e.g. reset_prefix_cache),
+    it appears in resumed_req_ids but NOT in preempted_req_ids.
+    preprocess_mamba must still clear its mamba_state_idx entry,
+    otherwise stale indices can point beyond the new block allocation.
+    """
+    spec = MagicMock(block_size=64, num_speculative_blocks=0)
+    cache_config = MagicMock(enable_prefix_caching=True)
+    input_batch = MagicMock(req_ids=[])  # batch reports no request ids
+
+    mamba_state_idx = {
+        "finished": 1,  # listed in finished_req_ids below
+        "preempted": 2,  # listed in preempted_req_ids below
+        "resumed": 3,  # only in resumed_req_ids, NOT in preempted
+        "keep": 99,  # in none of the id sets; must survive the call
+    }
+    sched = _make_scheduler_output(
+        finished_req_ids={"finished"},
+        preempted_req_ids={"preempted"},
+        resumed_req_ids={"resumed"},
+    )
+
+    with patch(
+        "vllm.v1.worker.mamba_utils.get_mamba_groups",
+        return_value=([0], spec),
+    ):
+        preprocess_mamba(
+            sched,
+            MagicMock(),
+            cache_config,
+            mamba_state_idx,
+            input_batch,
+            {},
+            {},
+            (),
+        )
+
+    assert mamba_state_idx == {"keep": 99}  # all three listed ids were popped
diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py
index a22b0eeb0..4f8a3bd05 100644
--- a/vllm/v1/worker/mamba_utils.py
+++ b/vllm/v1/worker/mamba_utils.py
@@ -129,7 +129,13 @@ def preprocess_mamba(
     block_size = mamba_spec.block_size
     finished_req_ids = scheduler_output.finished_req_ids
     preempted_req_ids = scheduler_output.preempted_req_ids or set()
-    for req_id in itertools.chain(finished_req_ids, preempted_req_ids):
+    # We need to clear mamba_state_idx for resumed requests. When requests are
+    # force-preempted (e.g., during reset_prefix_cache / KV cache flush),
+    # they appear in resumed_req_ids without a corresponding entry in
+    # preempted_req_ids, leaving stale mamba_state_idx entries that can
+    # point to block indices beyond the new (smaller) block allocation.
+    resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
+    for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
         mamba_state_idx.pop(req_id, None)
 
     src_state_list: list[int] = []