# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from unittest.mock import MagicMock, patch from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.worker.mamba_utils import preprocess_mamba def _make_scheduler_output( finished_req_ids: set[str], preempted_req_ids: set[str] | None, resumed_req_ids: set[str], ) -> SchedulerOutput: cached = CachedRequestData.make_empty() cached.resumed_req_ids = resumed_req_ids return SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=cached, num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=[], finished_req_ids=finished_req_ids, free_encoder_mm_hashes=[], preempted_req_ids=preempted_req_ids, ) def test_resumed_req_ids_cleared_from_mamba_state_idx(): """When a request is force-preempted (e.g. reset_prefix_cache), it appears in resumed_req_ids but NOT in preempted_req_ids. preprocess_mamba must still clear its mamba_state_idx entry, otherwise stale indices can point beyond the new block allocation. """ spec = MagicMock(block_size=64, num_speculative_blocks=0) cache_config = MagicMock(enable_prefix_caching=True) input_batch = MagicMock(req_ids=[]) copy_bufs = MagicMock(mamba_group_ids=[0], mamba_spec=spec) mamba_state_idx = { "finished": 1, "preempted": 2, "resumed": 3, # only in resumed_req_ids, NOT in preempted "keep": 99, } sched = _make_scheduler_output( finished_req_ids={"finished"}, preempted_req_ids={"preempted"}, resumed_req_ids={"resumed"}, ) with patch( "vllm.v1.worker.mamba_utils.get_mamba_groups", return_value=([0], spec), ): preprocess_mamba( sched, MagicMock(), cache_config, mamba_state_idx, input_batch, {}, {}, (), copy_bufs, ) assert mamba_state_idx == {"keep": 99}