# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses from unittest.mock import Mock import pytest import torch from vllm.config import ( CacheConfig, ECTransferConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig, ) from vllm.multimodal.inputs import ( MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange, ) from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.utils.hashing import sha256 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, ) from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv pytestmark = pytest.mark.cpu_test def test_add_requests(): scheduler = create_scheduler() requests = create_requests(num_requests=10) for i, request in enumerate(requests): scheduler.add_request(request) assert request.request_id in scheduler.requests assert len(scheduler.waiting) == i + 1 def test_finish_request(): scheduler = create_scheduler() requests = create_requests(num_requests=10) for request in requests: scheduler.add_request(request) for i, request in enumerate(requests): scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED) assert request.request_id not in scheduler.requests assert len(scheduler.waiting) == 9 - i def test_get_num_unfinished_requests(): scheduler = create_scheduler() requests = create_requests(num_requests=10) for request in requests: scheduler.add_request(request) for i, request in enumerate(requests): scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED) assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 @pytest.mark.parametrize( "enable_prefix_caching, prompt_logprobs", [ (False, None), (True, 5), ], ) def test_schedule(enable_prefix_caching: bool, prompt_logprobs: int | None): """Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs """ scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) requests = create_requests(num_requests=10, prompt_logprobs=prompt_logprobs) for request in requests: scheduler.add_request(request) # Test initial scheduling output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # Verify all requests are scheduled. for req_id, num_tokens in output.num_scheduled_tokens.items(): assert num_tokens == len(requests[int(req_id)].prompt_token_ids) # Verify requests moved from waiting to running assert len(scheduler.waiting) == 0 assert len(scheduler.running) == len(requests) for i, request in enumerate(requests): assert scheduler.running[i] == request def test_schedule_multimodal_requests(): scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf") mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)] requests = create_requests( num_requests=10, num_tokens=200, mm_positions=mm_positions, ) for request in requests: scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 for req_id, num_tokens in output.num_scheduled_tokens.items(): assert num_tokens == len(requests[int(req_id)].prompt_token_ids) assert len(output.scheduled_encoder_inputs) == 10 for req_id, encoder_input in output.scheduled_encoder_inputs.items(): assert len(encoder_input) == 1 def test_async_scheduling_pp_allows_rescheduling_with_output_placeholders(): """Async scheduling + PP: allow multi-step in-flight scheduling per request""" scheduler = create_scheduler(async_scheduling=True, pipeline_parallel_size=2) (req,) = create_requests(num_requests=1, num_tokens=8) scheduler.add_request(req) _ = scheduler.schedule() assert req.num_output_placeholders > 0 # before any update_from_output, we still expect the request can be # scheduled again (multi-step in-flight). output = scheduler.schedule() assert req.request_id in output.num_scheduled_tokens def test_schedule_partial_requests(): """Test scheduling behavior with partial requests. This test verifies that: 1. The scheduler can handle multiple partial requests in a single step when constrained by encoder budget. 2. A request in RUNNING state may be unscheduled in subsequent steps if there is insufficient encoder budget. """ scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, ) mm_positions = [[PlaceholderRange(offset=100, length=600)] for _ in range(3)] requests = create_requests( num_requests=3, num_tokens=800, mm_positions=mm_positions, ) for request in requests: scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 3 assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 assert scheduler.max_num_encoder_input_tokens == 1024 # The first request is scheduled fully. assert output.num_scheduled_tokens[requests[0].request_id] == 800 # The second request is scheduled partially. # The tokens are not scheduled because of the encoder budget. assert output.num_scheduled_tokens[requests[1].request_id] == 100 # The third request is also scheduled partially. # The tokens are not scheduled because of the encoder budget. assert output.num_scheduled_tokens[requests[2].request_id] == 100 req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, # Only the first request has a sampled token id because # the rest requests are still being prefilled. sampled_token_ids=[[0], [], []], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) # Schedule the next step. # Only the first and second requests are scheduled. # The third request is in the RUNNING state but not scheduled in this step # because of the encoder budget. output = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output.scheduled_new_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 2 assert len(output.finished_req_ids) == 0 assert output.num_scheduled_tokens[requests[0].request_id] == 1 assert output.num_scheduled_tokens[requests[1].request_id] == 700 assert requests[2].request_id not in output.num_scheduled_tokens def test_no_mm_input_chunking(): # Disable multimodal input chunking. scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, disable_chunked_mm_input=True, max_model_len=2048, ) mm_positions = [[PlaceholderRange(offset=400, length=800)]] requests = create_requests( num_requests=1, num_tokens=1200, mm_positions=mm_positions ) for request in requests: scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # We want to only see the 400 text tokens at the start scheduled assert output.num_scheduled_tokens[requests[0].request_id] == 400 req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) output = scheduler.schedule() assert len(scheduler.running) == 1 assert len(output.scheduled_new_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 1 assert len(output.finished_req_ids) == 0 assert output.num_scheduled_tokens[requests[0].request_id] == 800 # Test that we fail if we disable chunked mm input and use too small # of a max_num_batched_tokens for the mm input. with pytest.raises(ValueError): _ = create_scheduler( model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=100, disable_chunked_mm_input=True, ) @pytest.mark.parametrize("enable_prefix_caching", [True, False]) def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): """Test scheduling behavior with concurrent partial requests. This test verifies that: there are multiple long prefill requests in the RUNNING state, and we can schedule them together. """ scheduler = create_scheduler( model="facebook/opt-125m", max_num_batched_tokens=1024, long_prefill_token_threshold=400, enable_prefix_caching=enable_prefix_caching, ) requests = create_requests( num_requests=3, num_tokens=800, ) for request in requests: scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 3 assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # The first request is scheduled partially - 400. assert output.num_scheduled_tokens[requests[0].request_id] == 400 # The second request is scheduled partially - 400. assert output.num_scheduled_tokens[requests[1].request_id] == 400 # The third request is also scheduled partially - 1024 - 400 - 400 = 224. assert output.num_scheduled_tokens[requests[2].request_id] == 224 req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) # Schedule the next step. All three requests are running. # Processed the remaining prefills of the first and second requests. output1 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output1.scheduled_new_reqs) == 0 assert output1.scheduled_cached_reqs.num_reqs == 3 assert len(output1.finished_req_ids) == 0 assert output1.num_scheduled_tokens[requests[0].request_id] == 400 assert output1.num_scheduled_tokens[requests[1].request_id] == 400 assert output1.num_scheduled_tokens[requests[2].request_id] == 224 # Schedule the third step. All three requests are running. # First and second requests are in the decode stage. # All the remaining tokens in the third request are processed. model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output1, model_runner_output) output2 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output2.scheduled_new_reqs) == 0 assert output2.scheduled_cached_reqs.num_reqs == 3 assert len(output2.finished_req_ids) == 0 assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1 assert output2.num_scheduled_tokens[requests[2].request_id] == 800 - 224 - 224 def test_stop_via_update_from_output(): """Test stopping behavior through update_from_output""" scheduler = create_scheduler(num_speculative_tokens=1) # Test case 1: Stop on EOS token requests = create_requests(num_requests=2, max_tokens=10) for req in requests: req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) req.status = RequestStatus.RUNNING scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={requests[0].request_id: 1, requests[1].request_id: 2}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [], requests[1].request_id: [10], }, num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[ [EOS_TOKEN_ID], [10, 11], ], # First request hits EOS, second continues logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(scheduler_output, model_output) # Verify first request stopped, second continues assert len(scheduler.running) == 1 assert scheduler.running[0].request_id == requests[1].request_id assert requests[0].status == RequestStatus.FINISHED_STOPPED assert requests[0].request_id in scheduler.finished_req_ids assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID] assert list(requests[1].output_token_ids) == [10, 11] # Test case 2: Stop on custom stop token scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43]) for req in requests: req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) req.status = RequestStatus.RUNNING scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 2}, total_num_scheduled_tokens=5, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [10, 42], requests[1].request_id: [13], }, num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(scheduler_output, model_output) # Verify first request stopped on custom token assert len(scheduler.running) == 1 assert scheduler.running[0].request_id == requests[1].request_id assert requests[0].status == RequestStatus.FINISHED_STOPPED assert requests[0].stop_reason == 42 assert requests[0].request_id in scheduler.finished_req_ids assert list(requests[0].output_token_ids) == [10, 42] assert list(requests[1].output_token_ids) == [13, 14] # Test case 3: Stop on max tokens scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=2, max_tokens=2) for req in requests: req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) req.status = RequestStatus.RUNNING scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 1}, total_num_scheduled_tokens=4, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [10, 11], requests[1].request_id: [], }, num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(scheduler_output, model_output) # Verify first request stopped due to length assert len(scheduler.running) == 1 assert scheduler.running[0].request_id == requests[1].request_id assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED assert requests[0].request_id in scheduler.finished_req_ids assert list(requests[0].output_token_ids) == [10, 11] # Truncated to max_tokens assert list(requests[1].output_token_ids) == [13] # Test case 4: Ignore EOS flag scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=1, max_tokens=10, ignore_eos=True) requests[0].num_computed_tokens = requests[0].num_tokens scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={requests[0].request_id: 3}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]}, num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], ) model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(scheduler_output, model_output) # Verify request continues past EOS assert len(scheduler.running) == 1 assert not requests[0].is_finished() assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] def test_check_stop_min_tokens(): """Test that requests don't stop when min_tokens requirement isn't met.""" from vllm.v1.core.sched.utils import check_stop # Test case 1: num_output_tokens < min_tokens # Should return False (don't stop) sampling_params = SamplingParams( ignore_eos=False, max_tokens=20, min_tokens=5, ) sampling_params.update_from_generation_config({}, EOS_TOKEN_ID) request = Request( request_id="0", prompt_token_ids=[0, 1, 2], sampling_params=sampling_params, pooling_params=None, ) # Simulate having generated 3 output tokens (less than min_tokens=5) request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present result = check_stop(request, max_model_len=100) assert result is False, "Should not stop when num_output_tokens= min_tokens # Should follow normal stopping logic (stop on EOS) request.append_output_token_ids( [ 10, 11, 12, 13, 14, EOS_TOKEN_ID, ] ) # 6 tokens > min_tokens result = check_stop(request, max_model_len=100) assert result is True, "Should stop on EOS when min_tokens met" assert request.status == RequestStatus.FINISHED_STOPPED # Test case 3: min_tokens = 0, should follow normal stopping logic sampling_params_no_min = SamplingParams( ignore_eos=False, max_tokens=20, min_tokens=0, ) sampling_params_no_min.update_from_generation_config({}, EOS_TOKEN_ID) request_no_min = Request( request_id="1", prompt_token_ids=[0, 1, 2], sampling_params=sampling_params_no_min, pooling_params=None, ) request_no_min.append_output_token_ids([10, EOS_TOKEN_ID]) result = check_stop(request_no_min, max_model_len=100) assert result is True, "Should stop on EOS when min_tokens=0" assert request_no_min.status == RequestStatus.FINISHED_STOPPED # Test case 4: min_tokens > 0 with stop token (not EOS) sampling_params_stop = SamplingParams( ignore_eos=False, max_tokens=20, min_tokens=5, stop_token_ids=[42], ) sampling_params_stop.update_from_generation_config({}, EOS_TOKEN_ID) request_stop = Request( request_id="2", prompt_token_ids=[0, 1, 2], sampling_params=sampling_params_stop, pooling_params=None, ) # Only 3 output tokens, less than min_tokens=5, but has stop token request_stop.append_output_token_ids([10, 11, 42]) result = check_stop(request_stop, max_model_len=100) assert result is False, "Should not stop when num_output_tokens= min_tokens=5 result = check_stop(request_stop, max_model_len=100) assert result is True, "Should stop on stop token when min_tokens met" assert request_stop.status == RequestStatus.FINISHED_STOPPED assert request_stop.stop_reason == 42 @pytest.mark.parametrize( "enable_prefix_caching, prompt_logprobs", [ (False, None), (True, 5), ], ) def test_schedule_concurrent_batches( enable_prefix_caching: bool, prompt_logprobs: int | None ): scheduler = create_scheduler( max_num_batched_tokens=1024, max_num_seqs=2, enable_prefix_caching=enable_prefix_caching, ) requests = create_requests( num_requests=2, num_tokens=512, prompt_logprobs=prompt_logprobs, ) # Schedule the first request. scheduler.add_request(requests[0]) scheduler_output0 = scheduler.schedule() assert len(scheduler_output0.scheduled_new_reqs) == 1 assert scheduler_output0.num_scheduled_tokens[requests[0].request_id] == 512 # The first request is still running, so only schedule the second request. scheduler.add_request(requests[1]) scheduler_output1 = scheduler.schedule() assert len(scheduler_output1.scheduled_new_reqs) == 1 assert scheduler_output1.num_scheduled_tokens[requests[1].request_id] == 512 # Model output of the first request. model_runner_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(scheduler_output0, model_runner_output) # Schedule the next step. # The first request can be scheduled again while the second # request is still running. scheduler_output2 = scheduler.schedule() assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1 # Model output of the second request. model_runner_output = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(scheduler_output1, model_runner_output) @pytest.mark.parametrize("enable_chunked_prefill", [True, False]) def test_schedule_order(enable_chunked_prefill: bool): scheduler = create_scheduler( max_num_batched_tokens=1024, max_num_seqs=3, enable_chunked_prefill=enable_chunked_prefill, ) # long requests requests = create_requests(num_requests=2, num_tokens=800, req_ids=["1", "2"]) # short requests requests += create_requests(num_requests=2, num_tokens=10, req_ids=["3", "4"]) for request in requests: scheduler.add_request(request) scheduler_output1 = scheduler.schedule() if enable_chunked_prefill: # When enable chunked prefill, long requests will be chunked. assert len(scheduler_output1.scheduled_new_reqs) == 2 else: # When disable chunked prefill, should not skip the long requests, # and scheduling subsequent short requests in advance, # even though there is still token budgets remaining. assert len(scheduler_output1.scheduled_new_reqs) == 1 def test_preempt_during_execution(): # NOTE(woosuk): The actual number of available blocks is 10 instead of 11 # because block 0 is reserved as the null block. scheduler = create_scheduler( max_num_batched_tokens=100, block_size=16, num_blocks=11, enable_prefix_caching=False, ) requests = create_requests(num_requests=2, num_tokens=80, block_size=16) # Schedule the first request. scheduler.add_request(requests[0]) scheduler_output0 = scheduler.schedule() assert len(scheduler_output0.num_scheduled_tokens) == 1 assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5 # Schedule the second request while the first request is still running. # This scenario can occur in certain cases, when max_concurrent_batches > 1 # (e.g., when pipeline parallelism is used). scheduler.add_request(requests[1]) scheduler_output1 = scheduler.schedule() assert len(scheduler_output1.num_scheduled_tokens) == 1 assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5 # Get the output of the first request. model_runner_output0 = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(scheduler_output0, model_runner_output0) # Schedule the first request again. This will cause the preemption # of the second request because the KV cache is full. _ = scheduler.schedule() assert len(scheduler.running) == 1 assert scheduler.running[0] == requests[0] assert requests[1].status == RequestStatus.PREEMPTED model_runner_output1 = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[42]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(scheduler_output1, model_runner_output1) # The second request (that is preempted) should be updated with the # sampled token id. assert len(requests[1].output_token_ids) == 1 assert requests[1].output_token_ids[0] == 42 def test_scheduler_reset_prefix_cache(): scheduler = create_scheduler(enable_prefix_caching=True) requests = create_requests(num_requests=10) for request in requests: scheduler.add_request(request) # Initial scheduling, requests should be at the running state now _ = scheduler.schedule() # Verify requests moved from waiting to running assert len(scheduler.waiting) == 0 assert len(scheduler.running) == len(requests) for i, request in enumerate(requests): assert scheduler.running[i] == request # Reset prefix cache should fail since there are still running requests # and they are taking KV cache assert not scheduler.reset_prefix_cache() # Reset prefix cache with reset_running_requests=True. All running requests # Should be pushed back to the waiting queue and kv cache should be freed assert scheduler.reset_prefix_cache(reset_running_requests=True) # Verify requests moved from running to waiting assert len(scheduler.waiting) == len(requests) assert len(scheduler.running) == 0 for i, request in enumerate(requests): assert scheduler.waiting[i] == request # Note - these test cases mirror some of those in test_rejection_sampler.py @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", [ ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence ([[]], [[5]], (0, 0, 0, [0])), # empty sequence ( [[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]], (2, 6, 3, [2, 1, 0]), ), # multiple mismatches ], ) def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): """Test scheduling behavior with speculative decoding. This test verifies that: 1. Speculated tokens get scheduled correctly 2. Spec decoding stats properly count number of draft and accepted tokens """ num_spec_tokens = max(1, max(len(t) for t in spec_tokens)) scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens) requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) req_ids = [] req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) req_to_index[request.request_id] = i # Schedule a decode, which will also draft speculative tokens output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) assert output.total_num_scheduled_tokens == len(requests) for i in range(len(requests)): req_id = requests[i].request_id assert output.num_scheduled_tokens[req_id] == 1 assert req_id not in output.scheduled_spec_decode_tokens model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[0] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) draft_token_ids = DraftTokenIds(req_ids, spec_tokens) scheduler.update_draft_token_ids(draft_token_ids) for i in range(len(requests)): running_req = scheduler.running[i] # The prompt token assert running_req.num_computed_tokens == 1 # The prompt token and the sampled token assert running_req.num_tokens == 2 # The prompt token, the sampled token, and the speculated tokens assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i]) # No draft or accepted tokens counted yet assert not engine_core_outputs or ( engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None ) # Schedule the speculated tokens for validation output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 0 # The sampled token and speculated tokens assert output.total_num_scheduled_tokens == len(requests) + sum( len(ids) for ids in spec_tokens ) for i in range(len(requests)): req_id = requests[i].request_id assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i]) if spec_tokens[i]: assert len(output.scheduled_spec_decode_tokens[req_id]) == len( spec_tokens[i] ) else: assert req_id not in output.scheduled_spec_decode_tokens model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=output_tokens, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) scheduler_stats = ( engine_core_outputs[0].scheduler_stats if engine_core_outputs else None ) if expected[0] == 0: assert scheduler_stats is not None assert scheduler_stats.spec_decoding_stats is None else: assert scheduler_stats is not None assert scheduler_stats.spec_decoding_stats is not None stats = scheduler_stats.spec_decoding_stats assert stats.num_drafts == expected[0] assert stats.num_draft_tokens == expected[1] assert stats.num_accepted_tokens == expected[2] assert stats.num_accepted_tokens_per_pos == expected[3] def test_spec_decoding_stats_empty_output(): """Test that spec decoding stats handle empty output tokens gracefully. This is a regression test for a bug where empty sampled_token_ids would cause num_accepted = len([]) - 1 = -1, leading to a ValueError when incrementing a Prometheus counter with a negative value. """ num_spec_tokens = 3 scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens) requests = create_requests(num_requests=1, num_tokens=1) request = requests[0] req_id = request.request_id scheduler.add_request(request) # Initial schedule (prefill) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 # Complete the prefill with a sampled token model_runner_output = ModelRunnerOutput( req_ids=[req_id], req_id_to_index={req_id: 0}, sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) # Add draft tokens for speculation draft_token_ids = DraftTokenIds([req_id], [[1, 2, 3]]) scheduler.update_draft_token_ids(draft_token_ids) # Schedule the speculated tokens for validation output = scheduler.schedule() assert req_id in output.scheduled_spec_decode_tokens assert len(output.scheduled_spec_decode_tokens[req_id]) == 3 # Simulate empty output tokens (e.g., due to request abortion or error) # This would previously cause num_accepted = -1 and crash model_runner_output = ModelRunnerOutput( req_ids=[req_id], req_id_to_index={req_id: 0}, sampled_token_ids=[[]], # Empty output tokens logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # This should not raise an error engine_core_outputs = scheduler.update_from_output(output, model_runner_output) # Spec decoding stats should be None since no tokens were generated scheduler_stats = ( engine_core_outputs[0].scheduler_stats if engine_core_outputs else None ) assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None def test_no_spec_tokens_scheduled_for_prefill_chunks(): """Test that draft tokens are ignored for prefill chunk requests. When a request is being prefilled in chunks (chunked prefill), draft tokens from `update_draft_token_ids` should be ignored until the prefill is complete. The bug manifests when: - A prefill chunk is scheduled - Draft tokens are provided via update_draft_token_ids - The next schedule has enough budget to include spec tokens Without the fix, spec tokens would incorrectly be scheduled with the remaining prefill tokens. With the fix, draft tokens are ignored for prefill chunks. """ num_spec_tokens = 3 # Use budget of 50, with 80 token prompt: # - First chunk: 50 tokens # - Second chunk: 30 remaining + potentially 3 spec tokens = 33 # Without fix: num_scheduled_spec_tokens = 33 + 50 - 80 = 3 (BUG!) # With fix: spec_token_ids cleared, so no spec tokens scheduled scheduler = create_scheduler( num_speculative_tokens=num_spec_tokens, max_num_batched_tokens=50, enable_chunked_prefill=True, ) requests = create_requests(num_requests=1, num_tokens=80) req = requests[0] scheduler.add_request(req) # First schedule - prefill chunk (50 of 80 tokens) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 assert output.num_scheduled_tokens[req.request_id] == 50 # Update from output (no sampled token since still prefilling) req_to_index = {req.request_id: 0} model_runner_output = ModelRunnerOutput( req_ids=[req.request_id], req_id_to_index=req_to_index, sampled_token_ids=[[]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) # Provide draft tokens while request is still in prefill. # The fix ensures these are ignored for prefill chunks. draft_token_ids = DraftTokenIds([req.request_id], [[1, 2, 3]]) scheduler.update_draft_token_ids(draft_token_ids) # Second schedule - remaining 30 tokens of prefill output = scheduler.schedule() # KEY ASSERTION: Should schedule exactly the remaining 30 prefill tokens, # NOT 33 (30 + 3 spec). Without the fix, this would be 33. assert output.num_scheduled_tokens[req.request_id] == 30, ( f"Expected 30 tokens (remaining prefill only), " f"got {output.num_scheduled_tokens[req.request_id]}. " "Spec tokens should not be scheduled with prefill chunks." ) # No spec tokens should be in the output assert req.request_id not in output.scheduled_spec_decode_tokens, ( "Spec tokens should not be scheduled with prefill chunks" ) # Update from output with a sampled token (prefill complete) model_runner_output = ModelRunnerOutput( req_ids=[req.request_id], req_id_to_index=req_to_index, sampled_token_ids=[[42]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) # Now provide draft tokens - should be accepted since prefill is complete draft_token_ids = DraftTokenIds([req.request_id], [[1, 2, 3]]) scheduler.update_draft_token_ids(draft_token_ids) # spec_token_ids SHOULD be set after prefill is complete assert req.spec_token_ids == [1, 2, 3], ( f"spec_token_ids should be set after prefill, got {req.spec_token_ids}" ) # Third schedule - decode phase with spec tokens output = scheduler.schedule() # 1 new token + 3 spec tokens = 4 assert output.num_scheduled_tokens[req.request_id] == 4 assert req.request_id in output.scheduled_spec_decode_tokens assert len(output.scheduled_spec_decode_tokens[req.request_id]) == num_spec_tokens def _assert_right_scheduler_output( output: SchedulerOutput, num_requests: int, expected_num_scheduled_tokens: int, ): """Check if SchedulerOutput is correct after remote KV cache hit.""" # We should inject the kv_connector_metadata. assert len(output.kv_connector_metadata.requests) == num_requests # Only num_tokens - matched_num_new_tokens should be scheduled. for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): assert num_scheduled_tokens == expected_num_scheduled_tokens def _assert_right_kv_cache_manager( scheduler: Scheduler, requests: list[Request], num_tokens: int, block_size: int, num_requests: int, num_total_blocks: int, ): """Check whether KVCacheManager is correct after allocate.""" # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size for req in requests: blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ 0 ].req_to_blocks[req.request_id] hashes = req.block_hashes assert ( scheduler.kv_cache_manager.coordinator.single_type_managers[ 0 ].num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS ) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS # Make sure we actually touched all the blocks. BLOCKS_PER_REQ = num_tokens / block_size assert ( scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == num_total_blocks - num_requests * BLOCKS_PER_REQ ) def _step_until_done( scheduler: Scheduler, output: SchedulerOutput, model_runner_output: ModelRunnerOutput, ): """Loop over schedule(), update_from_output() until finished.""" all_finished = False _ = scheduler.update_from_output(output, model_runner_output) while not all_finished: # Schedule + a few iterations until stopping. output = scheduler.schedule() assert len(scheduler.running) for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): # We should be in the decode phase now. assert num_scheduled_tokens == 1 if scheduler.connector is not None: assert len(output.kv_connector_metadata.requests) == 0 if scheduler.ec_connector is not None: assert len(output.ec_connector_metadata.mm_datas) == 0 ecos = scheduler.update_from_output(output, model_runner_output)[0] all_done = True for eco in ecos.outputs: if eco.finish_reason is None: all_done = False all_finished = all_done def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): """Cycle requests through a KV transfer cyle.""" # Requests should first transition to WAITING_FOR_REMOTE_KVS output = scheduler.schedule() assert len(scheduler.waiting) == len(req_ids) assert len(scheduler.running) == 0 assert len(output.scheduled_new_reqs) == 0 for req in scheduler.requests.values(): assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS # No model execution yet EMPTY_OUTPUT = ModelRunnerOutput( req_ids=[], req_id_to_index={}, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) initial_ecos = scheduler.update_from_output(output, EMPTY_OUTPUT) # Simulate KV transfer completion using KVConnectorOutput.finished_recving output = scheduler.schedule() assert len(scheduler.waiting) == len(req_ids) assert len(scheduler.running) == 0 MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=[], req_id_to_index={}, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], kv_connector_output=KVConnectorOutput(finished_recving=req_ids), ) scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) for req_id in req_ids: assert req_id in scheduler.finished_recving_kv_req_ids return initial_ecos @pytest.mark.parametrize("is_async", [False, True]) def test_kv_connector_basic(is_async: bool): """ Test whether Scheduler with KVConnector schedules tokens, allocates memory, and cleans up requests as expected under normal operation. """ # Setup Scheduler. BLOCK_SIZE = 16 NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler = create_scheduler( enable_prefix_caching=True, use_kv_connector=mock_kv( matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async ), block_size=BLOCK_SIZE, ) NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ###################################################### # FIRST SET OF REQUESTS - External Hit Only NUM_REQUESTS = 2 NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2 MAX_TOKENS = 3 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS, block_size=BLOCK_SIZE, ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) req_to_index[request.request_id] = i if is_async: _step_until_kv_transfer_finished(scheduler, req_ids) MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # Ensure ScheduleOutput is correct. output = scheduler.schedule() _assert_right_scheduler_output( output=output, num_requests=NUM_REQUESTS, # Just the incremental tokens should be scheduled. expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, ) # Ensure KVCacheManager is correct. _assert_right_kv_cache_manager( scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS ) # Continue Generation until done. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) _ = scheduler.schedule() # Confirm we clean up the memory properly. assert ( scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS ) ###################################################### # SECOND SET OF REQUESTS - Local And External Hit NUM_TOKENS_PREFIX = NUM_TOKENS # We will get a local prefix cache hit for the first # NUM_TOKENS_PREFIX tokens since they are used above. NUM_TOKENS = NUM_TOKENS_PREFIX * 2 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS, block_size=BLOCK_SIZE, ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) req_to_index[request.request_id] = i if is_async: _step_until_kv_transfer_finished(scheduler, req_ids) MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # We should get a local cache hit of NUM_TOKENS_PREFIX and # a remote KV cache hit of NUM_MATCHED_NEW_TOKENS. output = scheduler.schedule() _assert_right_scheduler_output( output=output, num_requests=NUM_REQUESTS, # Just the incremental tokens after local + remote cache hit. expected_num_scheduled_tokens=( NUM_TOKENS - NUM_TOKENS_PREFIX - NUM_MATCHED_NEW_TOKENS ), ) # Ensure KVCacheManager is correct. _assert_right_kv_cache_manager( scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS ) # Continue Generation until done. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) _ = scheduler.schedule() # Confirm we clean up the memory properly. assert ( scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS ) @pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.parametrize("local_cache_hits", [False, True]) def test_external_prefix_cache_metrics(is_async: bool, local_cache_hits: bool): """ Verify connector prefix cache metrics are updated correctly when the scheduler processes requests with KV connector hits. """ BLOCK_SIZE = 16 if local_cache_hits: NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 # 32 tokens NUM_LOCAL_HITS = NUM_MATCHED_NEW_TOKENS * 2 # 64 tokens NUM_REQUESTS = 1 NUM_TOKENS = NUM_LOCAL_HITS * 2 # 128 tokens else: NUM_MATCHED_NEW_TOKENS = 4 NUM_LOCAL_HITS = 0 NUM_REQUESTS = 2 NUM_TOKENS = 8 # 8 tokens # Setup Scheduler. scheduler = create_scheduler( enable_prefix_caching=local_cache_hits, use_kv_connector=mock_kv( matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async ), block_size=BLOCK_SIZE, ) if local_cache_hits: # First, establish local cache by running a request to completion requests = create_requests( num_requests=1, num_tokens=NUM_LOCAL_HITS, max_tokens=2, block_size=BLOCK_SIZE, ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) req_to_index[request.request_id] = i if is_async: _step_until_kv_transfer_finished(scheduler, req_ids) # Run first request to completion to establish local cache output = scheduler.schedule() MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) _ = scheduler.schedule() # --- Prepare test requests --- MAX_TOKENS = 2 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS, block_size=BLOCK_SIZE, ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) req_to_index[request.request_id] = i initial_ecos = None if is_async: initial_ecos = _step_until_kv_transfer_finished(scheduler, req_ids) # --- Trigger scheduling and simulate model output --- output = scheduler.schedule() MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=[r.request_id for r in requests], req_id_to_index={r.request_id: i for i, r in enumerate(requests)}, sampled_token_ids=[[1000]] * NUM_REQUESTS, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # Update scheduler stats ecos = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) # --- Assertions --- assert ecos is not None and len(ecos) > 0 assert ecos[0].scheduler_stats is not None if local_cache_hits: # For async, local cache stats come from the first step if initial_ecos: local_stats = initial_ecos[0].scheduler_stats.prefix_cache_stats else: local_stats = ecos[0].scheduler_stats.prefix_cache_stats assert local_stats is not None assert local_stats.queries == NUM_TOKENS * NUM_REQUESTS assert local_stats.hits == NUM_LOCAL_HITS * NUM_REQUESTS if initial_ecos: external_stats = initial_ecos[0].scheduler_stats.connector_prefix_cache_stats else: external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats assert external_stats is not None assert external_stats.queries == (NUM_TOKENS - NUM_LOCAL_HITS) * NUM_REQUESTS assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS assert external_stats.requests == NUM_REQUESTS assert external_stats.preempted_requests == 0 @pytest.mark.parametrize( "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")] ) def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role): """ Test whether scheduler with KVConnector is able to handle unable to allocate (run out of blocks in allocate_slots(). """ # Setup Scheduler With Mock External Cache Hit. BLOCK_SIZE = 4 NUM_BLOCKS = 10 NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler = create_scheduler( enable_prefix_caching=True, use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False), block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, # encoder connector should not affect test results use_ec_connector=use_ec_connector, ec_role=ec_role, ) # Create two requests. The second request will not be able to # allocate slots because it will not have enough blocks. NUM_REQUESTS = 2 NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE MAX_TOKENS = 2 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS, block_size=BLOCK_SIZE, ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) req_to_index[request.request_id] = i MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # Just one request should be running. output = scheduler.schedule() _assert_right_scheduler_output( output, num_requests=1, expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 # All memory should be freed, with one request waiting. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 # Just one request should be running. output = scheduler.schedule() _assert_right_scheduler_output( output, num_requests=1, expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 # All memory should be freed, with no requests waiting / running. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 0 @pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.parametrize( "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")] ) def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role): """ Test whether scheduler with KVConnector is able to handle unable to allocate (run out of blocks in allocate_slots(). """ # Setup Scheduler With Mock External Cache Hit. BLOCK_SIZE = 2 # NOTE: there is 1 null block, so this is 6 blocks. NUM_BLOCKS = 7 NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler = create_scheduler( enable_prefix_caching=True, use_kv_connector=mock_kv( matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async ), block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, # encoder connector should not affect test results use_ec_connector=use_ec_connector, ec_role=ec_role, ) # Create two requests. # Both can be scheduled at first, but the second request # will be preempted and re-scheduled. NUM_REQUESTS = 2 NUM_TOKENS = BLOCK_SIZE * 2 + 1 MAX_TOKENS = BLOCK_SIZE * 2 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS, block_size=BLOCK_SIZE, ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) req_to_index[request.request_id] = i MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # All can be scheduled - 1st token. output = scheduler.schedule() if is_async: assert len(scheduler.waiting) == 2 assert scheduler.running == [] _step_until_kv_transfer_finished(scheduler, req_ids) output = scheduler.schedule() _assert_right_scheduler_output( output, # 2 remote kv cache hits. num_requests=2, expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, ) assert len(scheduler.running) == 2 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) # All can be scheduled - 2nd token. output = scheduler.schedule() _assert_right_scheduler_output( output, # no connector_metadata num_requests=0, expected_num_scheduled_tokens=1, ) assert len(scheduler.running) == 2 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) # This will generate a new block and cause a preemption - 3rd token. output = scheduler.schedule() _assert_right_scheduler_output( output, # no connector_metadata num_requests=0, expected_num_scheduled_tokens=1, ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 # Only 1 can be scheduled - 4th (and last token). output = scheduler.schedule() _assert_right_scheduler_output( output, # no connector_metadata num_requests=0, expected_num_scheduled_tokens=1, ) assert len(scheduler.waiting) == 1 assert len(scheduler.running) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) assert len(scheduler.running) == 0 # All memory should be freed since nothing is running. assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 # Restarts the preempted request - generate 3rd token. # This will have a local and remote cache hit. output = scheduler.schedule() if is_async: waiting_req_ids = [req.request_id for req in scheduler.waiting] assert len(waiting_req_ids) == 1 _step_until_kv_transfer_finished(scheduler, waiting_req_ids) output = scheduler.schedule() _assert_right_scheduler_output( output, # 1 remote kv_cache hit! num_requests=1, # Only 1 block was preempted and there is a single # remote hit. So only single new token is scheduled. expected_num_scheduled_tokens=1, ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 assert output.scheduled_cached_reqs.num_reqs == 1 assert output.scheduled_new_reqs == [] _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 # Only 1 can be scheduled - 4th (and last token). output = scheduler.schedule() _assert_right_scheduler_output( output, # no connector_metadata num_requests=0, expected_num_scheduled_tokens=1, ) assert output.scheduled_cached_reqs.num_reqs == 1 assert output.scheduled_new_reqs == [] assert len(scheduler.running) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) assert len(scheduler.running) == 0 # All memory should be freed since nothing is running. assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)}, sampled_token_ids=[[1000]] * len(scheduler.running), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) def assert_scheduler_empty(scheduler: Scheduler): """Confirm the scheduler is "empty" - i.e. no leaks.""" # Scheduler Metadata. assert len(scheduler.requests) == 0 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 0 assert len(scheduler.finished_req_ids) == 0 # EncoderCacheManager. assert len(scheduler.encoder_cache_manager.freed) == 0 assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. assert ( len( scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks ) == 0 ) assert ( len( scheduler.kv_cache_manager.coordinator.single_type_managers[ 0 ].num_cached_block ) == 0 ) num_free_blocks = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks ) assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. for block in scheduler.kv_cache_manager.block_pool.blocks: assert block.ref_cnt == 0 # assert block._block_hash is None # assert ( # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block # ) == 0) def test_memory_leak(): """Test that we do not have a memory leak.""" scheduler = create_scheduler(enable_prefix_caching=True) NUM_REQUESTS = 5 NUM_TOKENS = 10 MAX_TOKENS = 10 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS ) # Add each request. for request in requests: scheduler.add_request(request) scheduler_output = scheduler.schedule() model_runner_output = make_output(scheduler) scheduler.update_from_output(scheduler_output, model_runner_output) # Iterate until done. while True: scheduler_output = scheduler.schedule() if len(scheduler.running) == 0: break model_runner_output = make_output(scheduler) scheduler.update_from_output(scheduler_output, model_runner_output) # Confirm no memory leak. assert_scheduler_empty(scheduler) def create_scheduler_with_priority( model: str = "facebook/opt-125m", max_num_seqs: int = 16, max_num_batched_tokens: int = 8192, enable_prefix_caching: bool = False, long_prefill_token_threshold: int = 0, disable_chunked_mm_input: bool = False, use_kv_connector: bool = False, num_blocks: int = 10000, block_size: int = 16, max_model_len: int | None = None, num_speculative_tokens: int | None = None, use_ec_connector: bool = False, ec_role: str | None = None, ) -> Scheduler: """Create scheduler with priority policy enabled. Args: model: model under test max_num_seqs: max sequences to schedule max_num_batch_tokens: max num tokens to batch enable_prefix_caching: optionally force APC config (True/False) or use default (False) Returns: {class}`Scheduler` instance with priority scheduling """ model_config = ModelConfig( model=model, trust_remote_code=True, dtype="float16", seed=42, ) if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len, long_prefill_token_threshold=long_prefill_token_threshold, disable_chunked_mm_input=disable_chunked_mm_input, enable_chunked_prefill=True, is_encoder_decoder=model_config.is_encoder_decoder, policy="priority", # Enable priority scheduling ) # Cache config, optionally force APC cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, swap_space=0, cache_dtype="auto", enable_prefix_caching=enable_prefix_caching, ) kv_transfer_config = ( KVTransferConfig( kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": "local_storage"}, ) if use_kv_connector else None ) speculative_config: SpeculativeConfig | None = None if num_speculative_tokens is not None: speculative_config = SpeculativeConfig( model="ngram", num_speculative_tokens=num_speculative_tokens ) ec_transfer_config = ( ECTransferConfig( ec_connector="ECExampleConnector", ec_role=ec_role, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, ) if use_ec_connector else None ) vllm_config = VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, kv_transfer_config=kv_transfer_config, speculative_config=speculative_config, ec_transfer_config=ec_transfer_config, ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( ["layer"], FullAttentionSpec( block_size=block_size, num_kv_heads=1, head_size=1, dtype=torch.float32, ), ) ], ) cache_config.num_gpu_blocks = num_blocks return Scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), block_size=block_size, ) _none_hash_initialized = False def create_requests_with_priority( num_requests: int, priorities: list[int], arrival_times: list[float] | None = None, num_tokens: int = 10, mm_hashes_list: list[list[str]] | None = None, mm_positions: list[list[PlaceholderRange]] | None = None, max_tokens: int = 16, stop_token_ids: list[int] | None = None, prompt_logprobs: int | None = None, starting_idx: int = 0, same_prompt: bool = False, block_size: int = 16, req_ids: list[str] | None = None, ): """Create requests with specified priorities and arrival times.""" assert len(priorities) == num_requests if arrival_times is not None: assert len(arrival_times) == num_requests else: arrival_times = [float(i) for i in range(num_requests)] global _none_hash_initialized if not _none_hash_initialized: init_none_hash(sha256) _none_hash_initialized = True block_hasher = get_request_block_hasher(block_size, sha256) sampling_params = SamplingParams( ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, prompt_logprobs=prompt_logprobs, ) sampling_params.update_from_generation_config({}, EOS_TOKEN_ID) requests = [] if mm_hashes_list is not None: # NOTE: allow manual input; some mm items can have the same identifier # no. of mm_hashes and mm_positions for each request should be identical assert mm_positions is not None, ( "mm_positions must be provided when mm_hashes_list is provided" ) assert len(mm_hashes_list) == len(mm_positions) == num_requests assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions] # Since same identifier would imply they are identical encoder output # Verify mm items with identical identifier are having mm_position.length seen_hashes: dict[str, int] = {} if req_ids: assert len(req_ids) == num_requests else: req_ids = [f"{i + starting_idx}" for i in range(num_requests)] for i in range(num_requests): mm_features = [] for j, position in enumerate( mm_positions[i] if mm_positions is not None else [] ): if mm_hashes_list is not None: identifier = mm_hashes_list[i][j] # Verify if position length is identical position_length = position.length if identifier in seen_hashes: assert seen_hashes[identifier] == position_length, ( f"mm_hash '{identifier}' has inconsistent position lengths: " f"previously {seen_hashes[identifier]}, now {position_length} " f"at request {i}, position {j}" ) else: seen_hashes[identifier] = position_length else: # Unique dummy hash for each mm item identifier = f"hash{i}_{j}" mm_feature = MultiModalFeatureSpec( data=MultiModalKwargsItem.dummy(), mm_position=position, identifier=identifier, modality="image", ) mm_features.append(mm_feature) prompt_token_ids = ( [starting_idx] * num_tokens if same_prompt else [i + starting_idx] * num_tokens ) request = Request( request_id=req_ids[i], prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, pooling_params=None, mm_features=mm_features if mm_features else None, arrival_time=arrival_times[i], priority=priorities[i], block_hasher=block_hasher, ) requests.append(request) return requests def test_priority_scheduling_basic_ordering(): """Test that requests are scheduled in priority order (lower value = higher priority).""" scheduler = create_scheduler_with_priority() # Create requests with different priorities # Priority 0 (highest), 1, 2 (lowest) priorities = [2, 0, 1] # Add in non-priority order arrival_times = [1.0, 2.0, 3.0] # All different arrival times requests = create_requests_with_priority( num_requests=3, priorities=priorities, arrival_times=arrival_times ) # Add requests in non-priority order for request in requests: scheduler.add_request(request) # Schedule and verify priority order output = scheduler.schedule() # Should schedule all requests since they fit in budget assert len(output.scheduled_new_reqs) == 3 # Verify they are scheduled in priority order: # req_1 (priority 0), req_2 (priority 1), req_0 (priority 2) scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] assert scheduled_req_ids == ["1", "2", "0"] def test_priority_scheduling_arrival_time_tiebreaker(): """Test that arrival time is used as tiebreaker when priorities are equal.""" scheduler = create_scheduler_with_priority() # Create requests with same priority but different arrival times priorities = [1, 1, 1] # All same priority arrival_times = [3.0, 1.0, 2.0] # Different arrival times requests = create_requests_with_priority( num_requests=3, priorities=priorities, arrival_times=arrival_times ) # Add requests in non-arrival order for request in requests: scheduler.add_request(request) # Schedule and verify arrival time order output = scheduler.schedule() # Should schedule all requests since they fit in budget assert len(output.scheduled_new_reqs) == 3 # Verify they are scheduled in arrival time order: # req_1 (1.0), req_2 (2.0), req_0 (3.0) scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] assert scheduled_req_ids == ["1", "2", "0"] def test_priority_scheduling_mixed_priority_and_arrival(): """Test priority scheduling with mixed priorities and arrival times.""" scheduler = create_scheduler_with_priority() # Create requests with mixed priorities and arrival times priorities = [2, 1, 1, 0] # Mixed priorities arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times requests = create_requests_with_priority( num_requests=4, priorities=priorities, arrival_times=arrival_times ) # Add requests for request in requests: scheduler.add_request(request) # Schedule and verify order output = scheduler.schedule() # Should schedule all requests since they fit in budget assert len(output.scheduled_new_reqs) == 4 # Expected order: # 1. req_3 (priority 0, arrival 4.0) # 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1 # 3. req_1 (priority 1, arrival 3.0) # 4. req_0 (priority 2, arrival 1.0) scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] assert scheduled_req_ids == ["3", "2", "1", "0"] # This test had previously been passing due to its use of duplicate # request ids which resulted in incorrect behavior. # Now that the duplicate req ids had been fixed it fails and # investigation is needed into whether the priority scheduling # preemption logic is working as designed or not. @pytest.mark.skip("needs investigation") def test_priority_scheduling_preemption(): """Test that priority scheduling preempts lower priority requests when memory is constrained.""" # Create scheduler with very limited memory to force preemption scheduler = create_scheduler_with_priority( max_num_seqs=3, # Allow multiple requests max_num_batched_tokens=200, num_blocks=6, # Very limited blocks to force memory pressure block_size=16, # Standard block size ) # Create initial low-priority requests that will consume most memory low_priority_requests = create_requests_with_priority( num_requests=2, priorities=[5, 5], # Low priority arrival_times=[1.0, 2.0], num_tokens=30, # Large enough to consume significant memory, req_ids=["lo1", "lo2"], ) # Add and schedule low priority requests for request in low_priority_requests: scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 2 # Simulate model execution to move requests to running state model_output = ModelRunnerOutput( req_ids=[req.request_id for req in low_priority_requests], req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # Verify both requests are running assert len(scheduler.running) == 2 # Now add a high-priority request that requires memory allocation # This should trigger preemption due to memory constraints high_priority_request = create_requests_with_priority( num_requests=1, priorities=[0], # High priority arrival_times=[3.0], num_tokens=30, # Large enough to require significant memory req_ids=["hi1"], )[0] scheduler.add_request(high_priority_request) # Schedule again - this should trigger # preemption when trying to allocate memory output = scheduler.schedule() # Due to the scheduler's design, if preemption happens # during running request scheduling, # waiting requests won't be scheduled in the same step # Let's check if preemption occurred by looking at the waiting queue # If preemption happened, we should see requests in the # waiting queue if len(scheduler.waiting) > 1: # high priority + preempted request # Preemption occurred - verify the high priority request # gets scheduled next output2 = scheduler.schedule() assert len(output2.scheduled_new_reqs) == 1 # High priority request assert output2.scheduled_new_reqs[0].req_id == "hi1" else: # No preemption needed - all requests fit # This is also valid behavior if memory allows assert len(output.scheduled_new_reqs) == 1 # High priority request assert output.scheduled_new_reqs[0].req_id == "hi1" def test_priority_scheduling_no_preemption_when_space_available(): """Test that preemption doesn't happen when there's space for new requests.""" scheduler = create_scheduler_with_priority( max_num_seqs=3, # Allow 3 concurrent requests max_num_batched_tokens=200, # Sufficient token budget ) # Add two low-priority running requests low_priority_requests = create_requests_with_priority( num_requests=2, priorities=[5, 5], arrival_times=[1.0, 2.0], num_tokens=30, req_ids=["lo1", "lo2"], ) for request in low_priority_requests: scheduler.add_request(request) output = scheduler.schedule() model_output = ModelRunnerOutput( req_ids=[req.request_id for req in low_priority_requests], req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # Add high-priority request high_priority_request = create_requests_with_priority( num_requests=1, priorities=[0], arrival_times=[3.0], num_tokens=30, req_ids=["hi1"], )[0] scheduler.add_request(high_priority_request) # Schedule - should not preempt since there's space output = scheduler.schedule() # Should schedule the new request without preemption assert len(output.scheduled_new_reqs) == 1 assert len(scheduler.running) == 3 # All three requests running assert len(scheduler.waiting) == 0 # No requests waiting def test_priority_scheduling_preemption_victim_selection(): """Test that the correct victim is selected for preemption based on priority and arrival time.""" # This test verifies the priority-based victim selection logic # by checking the waiting queue order after adding requests with different # priorities scheduler = create_scheduler_with_priority( max_num_seqs=1, # Force sequential processing to test priority order ) # Create requests with different priorities requests = create_requests_with_priority( num_requests=3, priorities=[3, 2, 0], # Different priorities: low, medium, high arrival_times=[1.0, 2.0, 3.0], num_tokens=10, ) # Add all requests for request in requests: scheduler.add_request(request) # Schedule - should only schedule the highest priority request # (req_2, priority 0) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 assert output.scheduled_new_reqs[0].req_id == "2" # Highest priority # Verify the waiting queue has the remaining requests in priority order assert len(scheduler.waiting) == 2 # Extract waiting requests and verify priority order waiting_requests = list(scheduler.waiting) waiting_priorities = [req.priority for req in waiting_requests] waiting_req_ids = [req.request_id for req in waiting_requests] # Should be req_1 (priority 2) then req_0 (priority 3) assert waiting_priorities == [2, 3] assert waiting_req_ids == ["1", "0"] def test_priority_scheduling_equal_priority_preemption(): """Test arrival time tiebreaker when requests have equal priority.""" # This test verifies that arrival time is used as a tiebreaker for equal # priorities scheduler = create_scheduler_with_priority( max_num_seqs=1, # Force sequential processing ) # Create requests with same priority but different arrival times requests = create_requests_with_priority( num_requests=3, priorities=[2, 2, 2], # Same priority arrival_times=[3.0, 1.0, 2.0], # Different arrival times num_tokens=10, ) # Add all requests for request in requests: scheduler.add_request(request) # Schedule - should schedule the request with earliest arrival time output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 assert output.scheduled_new_reqs[0].req_id == "1" # Earliest arrival (1.0) # Verify the waiting queue has remaining requests in arrival time order assert len(scheduler.waiting) == 2 # Extract waiting requests and verify arrival time order waiting_requests = list(scheduler.waiting) waiting_arrival_times = [req.arrival_time for req in waiting_requests] waiting_req_ids = [req.request_id for req in waiting_requests] # Should be req_2 (arrival 2.0) then req_0 (arrival 3.0) assert waiting_arrival_times == [2.0, 3.0] assert waiting_req_ids == ["2", "0"] def test_priority_scheduling_waiting_queue_order(): """Test that the waiting queue maintains priority order.""" scheduler = create_scheduler_with_priority( max_num_seqs=1, # Only one request can run at a time ) # Create multiple requests with different priorities requests = create_requests_with_priority( num_requests=4, priorities=[3, 1, 2, 0], # Mixed priorities arrival_times=[1.0, 2.0, 3.0, 4.0], num_tokens=10, ) # Add all requests for request in requests: scheduler.add_request(request) # Schedule - should only schedule the highest priority request # (req_3, priority 0) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 assert output.scheduled_new_reqs[0].req_id == "3" # Verify waiting queue has remaining requests in priority order assert len(scheduler.waiting) == 3 # Extract requests from waiting queue # (it's a heap, so we need to pop to see order) waiting_requests = list(scheduler.waiting) waiting_priorities = [req.priority for req in waiting_requests] waiting_req_ids = [req.request_id for req in waiting_requests] # Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3) assert waiting_req_ids == ["1", "2", "0"] assert waiting_priorities == [1, 2, 3] def test_priority_scheduling_fcfs_fallback(): """Test that FCFS behavior is maintained when all requests have same priority.""" scheduler = create_scheduler_with_priority() # Create requests with same priority but different arrival times priorities = [1, 1, 1, 1] # All same priority arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times requests = create_requests_with_priority( num_requests=4, priorities=priorities, arrival_times=arrival_times ) # Add requests for request in requests: scheduler.add_request(request) # Schedule output = scheduler.schedule() # Should schedule all requests in arrival time order assert len(output.scheduled_new_reqs) == 4 scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] # Expected order by arrival time: # req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0) assert scheduled_req_ids == ["1", "3", "2", "0"] def test_priority_scheduling_with_limited_slots(): """Test priority scheduling when max_num_seqs limits concurrent requests.""" scheduler = create_scheduler_with_priority( max_num_seqs=2, # Only allow 2 concurrent requests max_num_batched_tokens=1000, # Plenty of token budget ) # Create requests with different priorities requests = create_requests_with_priority( num_requests=4, priorities=[3, 1, 2, 0], # Mixed priorities arrival_times=[1.0, 2.0, 3.0, 4.0], num_tokens=10, ) # Add all requests for request in requests: scheduler.add_request(request) # Schedule - should only schedule the 2 highest priority requests output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 2 # Should schedule req_3 (priority 0) and req_1 (priority 1) scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] assert "3" in scheduled_req_ids # Priority 0 assert "1" in scheduled_req_ids # Priority 1 # Remaining requests should be in waiting queue in priority order assert len(scheduler.waiting) == 2 # Extract waiting requests and verify order waiting_requests = list(scheduler.waiting) waiting_priorities = [req.priority for req in waiting_requests] waiting_req_ids = [req.request_id for req in waiting_requests] # Should be req_2 (priority 2) then req_0 (priority 3) assert waiting_priorities == [2, 3] assert waiting_req_ids == ["2", "0"] def test_priority_scheduling_heap_property(): """Test that the waiting queue maintains heap property for priority scheduling.""" scheduler = create_scheduler_with_priority( max_num_seqs=1, # Only one request can run at a time ) # Add requests in random priority order priorities = [5, 1, 8, 3, 2, 7, 4, 6] arrival_times = [float(i) for i in range(len(priorities))] requests = create_requests_with_priority( num_requests=len(priorities), priorities=priorities, arrival_times=arrival_times, num_tokens=10, ) # Add all requests for request in requests: scheduler.add_request(request) # Schedule one request at a time and verify priority order scheduled_priorities = [] while scheduler.waiting: output = scheduler.schedule() if output.scheduled_new_reqs: req = output.scheduled_new_reqs[0] scheduled_priorities.append(requests[int(req.req_id)].priority) # Simulate completion to make room for next request model_output = ModelRunnerOutput( req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, sampled_token_ids=[[100]], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # Finish the request to make room for the next one scheduler.finish_requests(req.req_id, RequestStatus.FINISHED_STOPPED) # Verify requests were scheduled in priority order (lowest value first) expected_priorities = sorted(priorities) assert scheduled_priorities == expected_priorities def test_schedule_skip_tokenizer_init(): scheduler = create_scheduler(skip_tokenizer_init=True) requests = create_requests(num_requests=5) for request in requests: scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) def test_schedule_skip_tokenizer_init_structured_output_request(): scheduler = create_scheduler(skip_tokenizer_init=True) structured_outputs_params = StructuredOutputsParams(regex="[0-9]+") sampling_params = SamplingParams( ignore_eos=False, max_tokens=16, structured_outputs=structured_outputs_params, ) sampling_params.update_from_generation_config({}, EOS_TOKEN_ID) request = Request( request_id="0", prompt_token_ids=[0, 1], mm_features=None, sampling_params=sampling_params, pooling_params=None, ) scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 0 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 @pytest.mark.parametrize( "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")] ) def test_priority_scheduling_preemption_and_resumption_when_out_of_kv( use_ec_connector, ec_role ): """Test that priority scheduling preempts lower priority requests when out of KV cache space.""" # Create scheduler with very limited memory to force preemption scheduler = create_scheduler_with_priority( max_num_seqs=2, # Allow multiple requests max_num_batched_tokens=200, num_blocks=5, # Can hold 64 tokens (first block is null) block_size=16, # Standard block size use_kv_connector=True, # encoder connector should not affect test results use_ec_connector=use_ec_connector, ec_role=ec_role, ) # Create a request and schedule it request_low = create_requests_with_priority( num_requests=1, priorities=[1], arrival_times=[0.0], num_tokens=30, starting_idx=0, )[0] scheduler.add_request(request_low) # 1st schedule output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 1 # Simulate model execution - 1st decode model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, sampled_token_ids=[[100]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # Create a high priority request and schedule it request_high = create_requests_with_priority( num_requests=1, priorities=[0], arrival_times=[1.0], num_tokens=32, starting_idx=1, )[0] scheduler.add_request(request_high) # 2nd schedule output = scheduler.schedule() # KV cache should be full at this point assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0 assert len(output.scheduled_new_reqs) == 1 assert output.scheduled_cached_reqs.num_reqs == 1 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 2 # Simulate model execution - 2nd decode requests = [request_low, request_high] model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[[100] for _ in requests], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # 3rd schedule - this should trigger preemption # req_low needs 32 tokens = 2 blocks # req_high needs 33 tokens = 3 blocks # so doesn't fit in 4 blocks. output = scheduler.schedule() # Should have preempted req_low assert len(output.scheduled_new_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 1 assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED assert len(scheduler.waiting) == 1 assert len(scheduler.running) == 1 # Simulate model execution - 3rd decode model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[[], [100]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # Finish the requests to make room for the preempted requests to resume scheduler.update_from_output(output, model_output) scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED) # 4th Schedule - this should trigger the resumption output = scheduler.schedule() scheduled_cached_reqs = output.scheduled_cached_reqs assert len(output.scheduled_new_reqs) == 0 assert scheduled_cached_reqs.num_reqs == 1 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 1 # Preempted request resumed in scheduled_cached_reqs assert len(scheduled_cached_reqs.resumed_req_ids) == 1 assert len(scheduled_cached_reqs.all_token_ids) == 1 assert scheduled_cached_reqs.req_ids[0] == request_low.request_id assert request_low.request_id in scheduled_cached_reqs.resumed_req_ids assert request_low.request_id in scheduled_cached_reqs.all_token_ids # Resumed tokens include 30 prompt tokens and 2 decoded tokens assert len(scheduled_cached_reqs.all_token_ids[request_low.request_id]) == 32 assert scheduled_cached_reqs.all_token_ids[request_low.request_id][31] == 100 @pytest.mark.parametrize( ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"), [ (True, False, True), (False, False, False), # Encoder-decoder models should always have it disabled (False, True, False), (True, True, False), ], ) def test_chunked_prefill_disabled_for_encoder_decoder( enable_chunked_prefill: bool, is_encoder_decoder: bool, expect_enabled: bool ) -> None: """Validate that chunked prefill is appropriately disabled for encoder-decoder models.""" scheduler_config = SchedulerConfig( enable_chunked_prefill=enable_chunked_prefill, is_encoder_decoder=is_encoder_decoder, # Must <= max_num_batched_tokens if chunked prefill is disabled max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) # `is_encoder_decoder` should only be used during construction # of the config, and otherwise stored in the model config. assert "is_encoder_decoder" not in vars(scheduler_config) assert "is_encoder_decoder" not in [ f.name for f in dataclasses.fields(scheduler_config) ] _validate_chunked_prefill_settings_for_encoder_decoder( scheduler_config, is_encoder_decoder, expect_enabled ) # Ensure it is retained in VllmConfig, even after its post-init. vllm_config = VllmConfig(scheduler_config=scheduler_config) _validate_chunked_prefill_settings_for_encoder_decoder( vllm_config.scheduler_config, is_encoder_decoder, expect_enabled ) def _validate_chunked_prefill_settings_for_encoder_decoder( scheduler_config: SchedulerConfig, is_encoder_decoder: bool, expect_enabled: bool ) -> None: """Validate chunked prefill settings in the scheduler config for encoder-decoder models.""" assert scheduler_config.enable_chunked_prefill is expect_enabled if is_encoder_decoder: # Encoder-decoder models should automatically disable chunked multimodal # inputs as well assert scheduler_config.disable_chunked_mm_input is not expect_enabled if is_encoder_decoder and not expect_enabled: assert scheduler_config.long_prefill_token_threshold == 0 # ============================================================================== # EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests start # NOTE: In E->P->D disagg case, both KV and EC Connector works in P instance # Unless specify, the existence of KV Connector should not affect any test results # ============================================================================== def _assert_right_encoder_cache_allocated( scheduler: Scheduler, hashes_to_check: list[str] | None = None, requests: list[Request] | None = None, expected_total_allocated: int | None = None, ): """Check whether encoder cache is allocated correctly.""" encoder_cache_manager = scheduler.encoder_cache_manager # Verify encoder cache manager exists assert encoder_cache_manager is not None, "Encoder cache manager should exist" # Verify number of cache if expected_total_allocated is not None: assert len(encoder_cache_manager.cached) == expected_total_allocated if expected_total_allocated == 0: return # Verify each request with MM data is in cache cached_hashes = set(encoder_cache_manager.cached.keys()) if hashes_to_check: missed_hashes = set(hashes_to_check) - cached_hashes assert not missed_hashes, ( f"Miss hashes: {missed_hashes} " f"Existing encoder cache: {encoder_cache_manager.cached}" ) for req in requests if requests is not None else []: if req.mm_features: mm_hashes = [f.identifier for f in req.mm_features] req_hashes = set(mm_hashes) # unique hashes set missed_hashes = req_hashes - cached_hashes assert not missed_hashes, ( f"Miss hashes in cache for request {req.request_id}: {missed_hashes} " f"Existing encoder cache: {encoder_cache_manager.cached}" ) def _assert_right_ec_connector_metadata( output: SchedulerOutput, mm_features_list: list[MultiModalFeatureSpec], ): """Verify that ECConnector metadata EXACTLY matches the input MM data""" # Get the connector metadata metadata = output.ec_connector_metadata # Create lookup dictionaries for efficient access metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas} # Check all required identifiers exist in metadata; and no extra # In ECExampleConnector format # NOTE: even having same identifier, the mm_features can be different # since their mm_position can be in different offsets, etc identifiers_dict = {f.identifier for f in mm_features_list} assert set(metadata_dict.keys()) == identifiers_dict # Verify the info matches for i, mm_feature in enumerate(mm_features_list): identifier = mm_feature.identifier assert metadata_dict[identifier].mm_hash == identifier assert metadata_dict[identifier].num_token == mm_feature.mm_position.length def _assert_right_encoder_inputs( output: SchedulerOutput, check_exist: bool | None = True, requests: list[Request] | None = None, expected_encoder_inputs: list[list[int]] | None = None, expected_total_reqs: int | None = None, ): """Verify that requests/mm_hashes should (not) in scheduled encoder input If check_exist is False, this function returns True if requests are NOT in encoder inputs""" # Get the scheduled encoder inputs # NOTE: scheduled_encoder_inputs is a dictionary with request id as key scheduled_encoder_inputs = output.scheduled_encoder_inputs # Check if scheduled_encoder_inputs is empty as expected if expected_total_reqs is not None: assert len(scheduled_encoder_inputs) == expected_total_reqs if expected_total_reqs == 0: return # Number of expected enocder inputs should match number of requests if expected_encoder_inputs: assert check_exist and requests is not None # only support expect input exist assert len(requests) == len(expected_encoder_inputs) # Check request (not) exist as expected for i, request in enumerate(requests if requests is not None else []): assert (request.request_id in scheduled_encoder_inputs) is check_exist, ( f"Request {request.id} presence mismatch: expected {check_exist}, " f"got {request.id in scheduled_encoder_inputs}" ) if expected_encoder_inputs: scheduled_encoder_input = scheduled_encoder_inputs[request.request_id] assert scheduled_encoder_input == expected_encoder_inputs[i] def test_scheduler_no_ec_connector_by_default(): """Test scheduler doesn't have EC connector by default.""" scheduler = create_scheduler() assert scheduler.ec_connector is None @pytest.mark.parametrize("use_kv_connector", [False, True]) def test_ec_connector_text_only_request(use_kv_connector): """Test text-only requests don't allocate encoder cache.""" scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", use_kv_connector=use_kv_connector, use_ec_connector=True, ec_role="ec_consumer", ) NUM_PROMPT_TOKENS = 100 # Create text-only request (no mm_positions) requests = create_requests( num_requests=1, num_tokens=NUM_PROMPT_TOKENS, ) assert not requests[0].mm_features # No MM data scheduler.add_request(requests[0]) output = scheduler.schedule() # Should schedule assert len(output.scheduled_new_reqs) == 1 # Scheduled tokens should equal prompt tokens exactly scheduled = output.num_scheduled_tokens[requests[0].request_id] assert scheduled == NUM_PROMPT_TOKENS, ( f"Text-only should schedule {NUM_PROMPT_TOKENS}, got {scheduled}" ) # Encoder cache should be empty _assert_right_encoder_cache_allocated(scheduler, expected_total_allocated=0) # ECConnector should carry no metadata _assert_right_ec_connector_metadata(output, mm_features_list=[]) # Scheduled encoder input should be empty; no mm to compute _assert_right_encoder_inputs(output, expected_total_reqs=0) @pytest.mark.parametrize("use_kv_connector", [False, True]) def test_ec_connector_cache_hit_external_load(use_kv_connector): """Test ec_consumer loads from external cache when hit. A normal basic operation for EPD disaggrgation""" scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", enable_prefix_caching=True, # kv connector should not effect test results use_kv_connector=use_kv_connector, use_ec_connector=True, ec_role="ec_consumer", ) # Create MM request NUM_TOKENS = 200 # NOTE: includes mm tokens NUM_ENCODER_TOKENS = 100 mm_hashes_list = [["hash_test1"]] mm_positions = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS)]] request = create_requests( num_requests=1, num_tokens=NUM_TOKENS, mm_hashes_list=mm_hashes_list, mm_positions=mm_positions, )[0] # Mock cache hit - encoder cache has_exists externally scheduler.ec_connector.has_cache_item = Mock(return_value=True) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) scheduler.add_request(request) output = scheduler.schedule() # Should schedule prompt tokens scheduled_tokens = output.num_scheduled_tokens[request.request_id] assert scheduled_tokens == NUM_TOKENS # Should called update_state_after_alloc for external load scheduler.ec_connector.update_state_after_alloc.assert_called_with(request, 0) # Encoder cache should contain mm items from request _assert_right_encoder_cache_allocated(scheduler, requests=[request]) # ECConnector should carry metadata of request _assert_right_ec_connector_metadata(output, mm_features_list=request.mm_features) # Scheduled encoder input should be empty; no mm to compute _assert_right_encoder_inputs(output, expected_total_reqs=0) @pytest.mark.parametrize("use_kv_connector", [False, True]) def test_ec_connector_cache_miss_computes_locally(use_kv_connector): """Test consumer can compute encoder locally when cache miss (fallback).""" # encoder cache itself if it doesn't receive it from external storage scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", enable_prefix_caching=True, use_kv_connector=use_kv_connector, use_ec_connector=True, ec_role="ec_consumer", ) # Verify consumer role assert scheduler.ec_connector is not None assert not scheduler.ec_connector.is_producer # Create MM request request_mm_missed = create_requests( num_requests=1, num_tokens=200, # Total (including 100 MM) mm_positions=[[PlaceholderRange(offset=0, length=100)]], # 100 MM tokens )[0] # Mock cache miss - encoder cache doesn't exist externally scheduler.ec_connector.has_cache_item = Mock(return_value=False) scheduler.add_request(request_mm_missed) output = scheduler.schedule() # SCHEDULER should decide to compute encoder locally (fallback) assert len(output.scheduled_new_reqs) == 1 # Should schedule full prompt tokens scheduled_tokens = output.num_scheduled_tokens[request_mm_missed.request_id] assert scheduled_tokens == 200, ( f"Expected 200 tokens on cache miss, got {scheduled_tokens}" ) # Encoder cache should contain mm items from request _assert_right_encoder_cache_allocated(scheduler, requests=[request_mm_missed]) # ECConnector should carry no metadata (missed cache) _assert_right_ec_connector_metadata(output, mm_features_list=[]) # Scheduled encoder input contain mm for request_mm_missed _assert_right_encoder_inputs( output, requests=[request_mm_missed], expected_encoder_inputs=[[0]], # index 0 of the mm item expected_total_reqs=1, ) # Then MODEL_RUNNER will execute the encoder and cache the result @pytest.mark.parametrize("use_kv_connector", [False, True]) def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector): """Test consumer with partial cache hit (local & connector) with 2 requests.""" scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", enable_prefix_caching=True, use_kv_connector=use_kv_connector, use_ec_connector=True, ec_role="ec_consumer", ) # Create MM request NUM_TOKENS_1 = 300 # NOTE: includes mm tokens NUM_ENCODER_TOKENS_1 = 50 mm_hashes_list_1 = [["hash1_A", "hash1_B", "hash1_A", "hash1_F"]] mm_positions_1 = [ [ PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1), PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_1), PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_1), PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1), ] ] has_cache_item_result_map_1 = {"hash1_A": False, "hash1_B": True, "hash1_F": True} # Create request with 4 MM items, with 2 identical items request1 = create_requests( num_requests=1, num_tokens=NUM_TOKENS_1, mm_hashes_list=mm_hashes_list_1, mm_positions=mm_positions_1, max_tokens=1, # For simplicity )[0] # Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist scheduler.ec_connector.has_cache_item = Mock( side_effect=lambda hash_val: has_cache_item_result_map_1[hash_val] ) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) scheduler.add_request(request1) output = scheduler.schedule() # Should schedule all tokens scheduled_tokens = output.num_scheduled_tokens[request1.request_id] assert scheduled_tokens == NUM_TOKENS_1 # Encoder cache should contain all mm items from request _assert_right_encoder_cache_allocated(scheduler, requests=[request1]) # Should have called update_state_after_alloc for external load scheduler.ec_connector.update_state_after_alloc.assert_called() scheduler.ec_connector.update_state_after_alloc.reset_mock() # ECConnector should carry metadata for 2nd and 4th mm item _assert_right_ec_connector_metadata( output, mm_features_list=[request1.mm_features[1], request1.mm_features[3]] ) # Should schedule ONLY 1 encoder input (index 0), no repeat for identical items _assert_right_encoder_inputs( output, requests=[request1], expected_encoder_inputs=[[0]], # index 0 of the mm item ONLY expected_total_reqs=1, ) # Simulate model execution 1 step model_output = ModelRunnerOutput( req_ids=[request1.request_id], req_id_to_index={request1.request_id: 0}, sampled_token_ids=[[100]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # request1 is finished after outputing 1 token # Finish request scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED) # Create another request with 4 MM items NUM_TOKENS_2 = 400 NUM_ENCODER_TOKENS_2 = 50 mm_hashes_list_2 = [["hash1_C", "hash1_D", "hash1_E", "hash1_A"]] mm_positions_2 = [ [ PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2), PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_2), PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_2), PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2), ] ] has_cache_item_result_map_2 = { "hash1_C": True, "hash1_D": False, "hash1_E": False, "hash1_A": True, } request2 = create_requests( num_requests=1, num_tokens=NUM_TOKENS_2, mm_hashes_list=mm_hashes_list_2, mm_positions=mm_positions_2, max_tokens=1, # For simplicity )[0] # Mock partial cache hit: only hash1_A and hash1_C exist in connector scheduler.ec_connector.has_cache_item = Mock( side_effect=lambda hash_val: has_cache_item_result_map_2[hash_val] ) scheduler.add_request(request2) output = scheduler.schedule() # Check # Should schedule all tokens scheduled_tokens = output.num_scheduled_tokens[request2.request_id] assert scheduled_tokens == 400 # Encoder cache should contain all mm items from request2 _assert_right_encoder_cache_allocated(scheduler, requests=[request2]) # Should call update_state_after_alloc for hash1_C, ONLY # hash1_A should not be loaded from connector # since it's computed in last request & exist in local cache # Order of getting encoder cache should be: local cache -> connector-> compute scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 0) scheduler.ec_connector.update_state_after_alloc.assert_called_once() scheduler.ec_connector.update_state_after_alloc.reset_mock() # ECConnector should carry metadata for hash1_C only (index 0) _assert_right_ec_connector_metadata( output, mm_features_list=[request2.mm_features[0]] ) # Should schedule 2 encoder input hash1_D and hash1_E (index 1, 2) _assert_right_encoder_inputs( output, requests=[request2], expected_encoder_inputs=[[1, 2]], expected_total_reqs=1, ) @pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"]) @pytest.mark.parametrize("use_kv_connector", [False, True]) def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector): scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", max_num_seqs=10, # allow multiple requests max_num_batched_tokens=2048, enable_prefix_caching=True, use_kv_connector=use_kv_connector, use_ec_connector=True, ec_role="ec_consumer", ) mm_hashes_list = [[f"hash_{i}"] for i in range(10)] mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)] requests = create_requests( num_requests=10, num_tokens=200, mm_hashes_list=mm_hashes_list, mm_positions=mm_positions, ) for request in requests: scheduler.add_request(request) # Set up to test different encoder cache exsistence scenario after preemption # Order of getting encoder cache should be: local cache -> connector-> compute scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) if cache_exist == "local": # Allocate cache to cache manager manually to mimick for req in requests: scheduler.encoder_cache_manager.allocate(req, 0) else: # Make sure local encoder cache empty scheduler.encoder_cache_manager.cached = {} if cache_exist == "connector_only": # Cache exist in ec_connector scheduler.ec_connector.has_cache_item = Mock(return_value=True) elif cache_exist == "no_where": scheduler.ec_connector.has_cache_item = Mock(return_value=False) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 for req_id, num_tokens in output.num_scheduled_tokens.items(): assert num_tokens == len(requests[int(req_id)].prompt_token_ids) ## Encoder-cache-specific checks: # mm_hashes of requests exist in cache after scheduling for all scenario _assert_right_encoder_cache_allocated(scheduler, requests=requests) # Should only call update_state_after_alloc when loaded externally if cache_exist == "connector_only": scheduler.ec_connector.update_state_after_alloc.assert_called_with( requests[-1], 0 ) # Concat mm_features for the 10 requests together mm_features_list = [feature for req in requests for feature in req.mm_features] # Check metadata should contain mm data for all 10 requests _assert_right_ec_connector_metadata(output, mm_features_list=mm_features_list) else: scheduler.ec_connector.update_state_after_alloc.assert_not_called() # ECConnector should carry no metadata _assert_right_ec_connector_metadata(output, mm_features_list=[]) scheduler.ec_connector.update_state_after_alloc.reset_mock() # Should only schedule encoder input when cache is not found anywhere if cache_exist == "no_where": _assert_right_encoder_inputs( output, requests=requests, expected_encoder_inputs=[[0] for _ in range(10)], expected_total_reqs=10, ) else: _assert_right_encoder_inputs(output, expected_total_reqs=0) @pytest.mark.parametrize("use_kv_connector", [False, True]) def test_ec_connector_unable_to_allocate(use_kv_connector): """ Test whether scheduler with ECConnector is able to handle unable to allocate (run out of blocks). """ # Setup Scheduler With Mock External Cache Hit. BLOCK_SIZE = 4 NUM_BLOCKS = 10 scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", enable_prefix_caching=True, use_kv_connector=use_kv_connector, block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, use_ec_connector=True, ec_role="ec_consumer", ) # Mock ec_connector load external cache behavior scheduler.ec_connector.has_cache_item = Mock(return_value=True) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) # Create two requests. The second request will not be able to # allocate slots because it will not have enough blocks. NUM_REQUESTS = 2 NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE MAX_TOKENS = 2 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, mm_hashes_list=[["hash_1"], ["hash_2"]], mm_positions=[ [PlaceholderRange(offset=1, length=10)] for _ in range(NUM_REQUESTS) ], max_tokens=MAX_TOKENS, block_size=BLOCK_SIZE, ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) req_to_index[request.request_id] = i # Setup MODEL_RUNNER_OUTPUT to be run in _step_until_done later MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # Just one request should be running. output = scheduler.schedule() scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id] assert scheduled_tokens == NUM_TOKENS assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 # Should have called update_state_after_alloc for external load scheduler.ec_connector.update_state_after_alloc.assert_called_with( scheduler.running[0], 0 ) scheduler.ec_connector.update_state_after_alloc.reset_mock() # All memory should be freed, with one request waiting. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 # Just one request should be running. output = scheduler.schedule() scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id] assert scheduled_tokens == NUM_TOKENS assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 # update_state_after_alloc should be called for loading external cache scheduler.ec_connector.update_state_after_alloc.assert_called_with( scheduler.running[0], 0 ) scheduler.ec_connector.update_state_after_alloc.reset_mock() # All memory should be freed, with no requests waiting / running. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 0 @pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"]) @pytest.mark.parametrize("use_kv_connector", [False, True]) def test_priority_scheduling_ec_connector_preemption_and_resumption( cache_exist, use_kv_connector ): """Test that priority scheduling preempts lower priority requests when out of KV cache space.""" # Create scheduler with very limited memory to force preemption scheduler = create_scheduler_with_priority( model="llava-hf/llava-1.5-7b-hf", enable_prefix_caching=True, max_num_seqs=2, # allow multiple requests # kv connector should not effect test results use_kv_connector=use_kv_connector, num_blocks=15, # can hold 244 tokens with 14 blocks (first block is null) block_size=16, # standard block size use_ec_connector=True, ec_role="ec_consumer", ) # Mock cache hit: Both cache exist in connector (at E->PD initially) scheduler.ec_connector.has_cache_item = Mock(return_value=True) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) # Create a request and schedule it (and to be preempted) request_low = create_requests_with_priority( num_requests=1, priorities=[1], arrival_times=[0.0], num_tokens=94, mm_hashes_list=[["hash_low"]], # NOTE: this test only preempt the last block. # Setting mm_position at the last block can force to recompute encoding mm_positions=[[PlaceholderRange(offset=82, length=10)]], starting_idx=0, )[0] scheduler.add_request(request_low) # 1st schedule output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 scheduled_tokens = output.num_scheduled_tokens[request_low.request_id] assert scheduled_tokens == 94 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 1 ## Encoder-cache-specific checks: # Encoder cache should contain mm items from request _assert_right_encoder_cache_allocated(scheduler, requests=[request_low]) # Verify update_state_after_alloc called (external load) scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_low, 0) scheduler.ec_connector.update_state_after_alloc.reset_mock() # ECConnector should carry metadata of request _assert_right_ec_connector_metadata( output, mm_features_list=request_low.mm_features ) # Scheduled encoder input should be empty; no mm to compute _assert_right_encoder_inputs(output, expected_total_reqs=0) # Simulate model execution - 1st decode model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, sampled_token_ids=[[100]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # Create a high priority request and schedule it request_high = create_requests_with_priority( num_requests=1, priorities=[0], arrival_times=[1.0], num_tokens=128, mm_hashes_list=[["hash_high"]], mm_positions=[[PlaceholderRange(offset=1, length=10)]], max_tokens=2, starting_idx=1, )[0] scheduler.add_request(request_high) # 2nd schedule output = scheduler.schedule() # KV cache should be full at this point assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0 assert len(output.scheduled_new_reqs) == 1 assert output.scheduled_cached_reqs.num_reqs == 1 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 2 ## Encoder-cache-specific checks: # Encoder cache should contain mm items from request _assert_right_encoder_cache_allocated(scheduler, requests=[request_high]) # Verify update_state_after_alloc called (external load) scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_high, 0) scheduler.ec_connector.update_state_after_alloc.reset_mock() # ECConnector should carry metadata of request _assert_right_ec_connector_metadata( output, mm_features_list=request_high.mm_features ) # Scheduled encoder input should be empty; no mm to compute _assert_right_encoder_inputs(output, expected_total_reqs=0) # Simulate model execution - 2nd decode requests = [request_low, request_high] model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[[100] for _ in requests], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # 3rd schedule - - this should trigger preemption # req_low needs 96 tokens = 6 blocks # req_high needs 129 tokens = 9 blocks # so doesn't fit in 14 blocks. output = scheduler.schedule() # Should have preempted req_low assert len(output.scheduled_new_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 1 assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED assert len(scheduler.waiting) == 1 assert len(scheduler.running) == 1 ## Encoder-cache-specific checks: # request_high is in decode phase now # ECConnector should carry no metadata _assert_right_ec_connector_metadata(output, mm_features_list=[]) # Scheduled encoder input should be empty; no mm to compute _assert_right_encoder_inputs(output, expected_total_reqs=0) # Simulate model execution - 3rd decode, after req_low was preempted requests = [request_low, request_high] model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[[100], [100, 200]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) # Finish the requests to make room for the preempted requests to resume # req_high is finished after outputing 2 tokens scheduler.update_from_output(output, model_output) scheduler.finish_requests( request_high.request_id, RequestStatus.FINISHED_LENGTH_CAPPED ) # Set up to test different encoder cache exsistence scenario after preemption # Order of getting encoder cache should be: local cache -> connector-> compute # By default, the cache should still exist in local in this test case if cache_exist != "local": # Make local encoder cache empty scheduler.encoder_cache_manager.cached = {} if cache_exist == "connector_only": # Cache exist in ec_connector scheduler.ec_connector.has_cache_item = Mock(return_value=True) elif cache_exist == "no_where": scheduler.ec_connector.has_cache_item = Mock(return_value=False) # 4th Schedule - this should trigger req_low resumption from waiting output = scheduler.schedule() scheduled_cached_reqs = output.scheduled_cached_reqs assert len(output.scheduled_new_reqs) == 0 assert scheduled_cached_reqs.num_reqs == 1 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 1 # Preempted request resumed in scheduled_cached_reqs assert len(scheduled_cached_reqs.resumed_req_ids) == 1 assert len(scheduled_cached_reqs.all_token_ids) == 1 assert scheduled_cached_reqs.req_ids[0] == request_low.request_id assert request_low.request_id in scheduled_cached_reqs.resumed_req_ids assert request_low.request_id in scheduled_cached_reqs.all_token_ids ## Resumed tokens include 94 prompt tokens and 2 decoded tokens assert len(scheduled_cached_reqs.all_token_ids[request_low.request_id]) == 96 assert scheduled_cached_reqs.all_token_ids[request_low.request_id][95] == 100 assert scheduler.running[0].request_id == request_low.request_id assert request_high.request_id in output.finished_req_ids ## Encoder-cache-specific checks: # mm_hash of request_low exists in cache after scheduling for all scenario _assert_right_encoder_cache_allocated(scheduler, requests=[request_low]) # Should only call update_state_after_alloc when loaded externally if cache_exist == "connector_only": scheduler.ec_connector.update_state_after_alloc.assert_called_with( request_low, 0 ) _assert_right_ec_connector_metadata( output, mm_features_list=request_low.mm_features ) else: scheduler.ec_connector.update_state_after_alloc.assert_not_called() # ECConnector should carry no metadata _assert_right_ec_connector_metadata(output, mm_features_list=[]) scheduler.ec_connector.update_state_after_alloc.reset_mock() # Should only schedule encoder input when cache is not found anywhere if cache_exist == "no_where": _assert_right_encoder_inputs( output, requests=[request_low], expected_encoder_inputs=[[0]], expected_total_reqs=1, ) else: _assert_right_encoder_inputs(output, expected_total_reqs=0) @pytest.mark.parametrize("use_kv_connector", [False, True]) def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector): """ Scenario: - Encoder cache size: 32 - Request A: 1 feature (12 tokens) → NOT cached remotely. - Request B: 3 features (3 x 10 tokens) → ALL cached remotely. Steps: 1. Schedule Request A (locally uses 12 tokens). 2. Schedule Request B (remote cache) - only schedule 1st and 2nd 3. Free A's cache, then schedule B again (continuation) - schedule 3rd image """ scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, enable_prefix_caching=True, use_kv_connector=use_kv_connector, block_size=16, num_blocks=11, # Can hold 160 tokens (first block is null) use_ec_connector=True, ec_role="ec_consumer", ) # Limit the number of availiable slots of EncoderCacheManager scheduler.encoder_cache_manager = EncoderCacheManager(cache_size=32) # Create MM request1 NUM_TOKENS_1 = 50 # NOTE: includes mm tokens NUM_ENCODER_TOKENS_1 = 12 mm_hashes_list_1 = [["hash1_1"]] mm_positions_1 = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1)]] request1 = create_requests( num_requests=1, num_tokens=NUM_TOKENS_1, mm_hashes_list=mm_hashes_list_1, mm_positions=mm_positions_1, max_tokens=1, # For simplicity req_ids=["req1"], )[0] # Create MM request1 with 3 MM items NUM_TOKENS_2 = 40 NUM_ENCODER_TOKENS_2 = 10 mm_hashes_list_2 = [["hash2_1", "hash2_2", "hash2_3"]] mm_positions_2 = [ [ PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2), PlaceholderRange(offset=12, length=NUM_ENCODER_TOKENS_2), PlaceholderRange(offset=24, length=NUM_ENCODER_TOKENS_2), ] ] request2 = create_requests( num_requests=1, num_tokens=NUM_TOKENS_2, mm_hashes_list=mm_hashes_list_2, mm_positions=mm_positions_2, max_tokens=10, req_ids=["req2"], )[0] # Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely scheduler.ec_connector.has_cache_item = Mock( side_effect=lambda hash_value: hash_value in mm_hashes_list_2[0] ) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) scheduler.add_request(request1) scheduler.add_request(request2) output = scheduler.schedule() # Now, since encoder cache manager can only store 32 tokens # It should allocated mm item hash1_1, hash2_1 and hash2_2 scheduled_tokens = output.num_scheduled_tokens[request1.request_id] assert scheduled_tokens == NUM_TOKENS_1 assert scheduler.get_num_unfinished_requests() == 2 # Encoder cache should contain mm item from request1 _assert_right_encoder_cache_allocated( scheduler, hashes_to_check=["hash1_1", "hash2_1", "hash2_2"] ) # request2's 2nd mm item is the last call of update_state_after_alloc scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 1) scheduler.ec_connector.update_state_after_alloc.reset_mock() # ECConnector should carry metadata of hash2_1 and hash2_2 ONLY _assert_right_ec_connector_metadata( output, mm_features_list=[request2.mm_features[0], request2.mm_features[1]] ) # Should schedule ONLY 1 encoder input _assert_right_encoder_inputs( output, requests=[request1], expected_encoder_inputs=[[0]], # index 0 of the mm item of request1 expected_total_reqs=1, ) # Simulate model execution 1 step model_output = ModelRunnerOutput( req_ids=[request1.request_id, request2.request_id], req_id_to_index={request1.request_id: 0, request2.request_id: 1}, sampled_token_ids=[[100], [121]], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) scheduler.update_from_output(output, model_output) # request1 is finished after outputing 1 token # Finish request scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED) assert scheduler.get_num_unfinished_requests() == 1 # Schedule again; Now request1's encoder cache should be freed # -> hash2_3 can be scheduled and allocated output = scheduler.schedule() # Check # Should schedule all tokens scheduled_tokens = output.num_scheduled_tokens[request2.request_id] print(f"Hero: scheduled_tokens for req2: {scheduled_tokens}") print(f"hero: num_scheduled_tokens 2: {output.num_scheduled_tokens}") # Encoder cache should contain all mm items from request2 _assert_right_encoder_cache_allocated(scheduler, requests=[request2]) # request2's 3rd mm item is the ONLY call of update_state_after_alloc scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 2) scheduler.ec_connector.update_state_after_alloc.assert_called_once() scheduler.ec_connector.update_state_after_alloc.reset_mock() # ECConnector should carry metadata for hash2_3 ONLY _assert_right_ec_connector_metadata( output, mm_features_list=[request2.mm_features[2]] ) # Should schedule no encoder input _assert_right_encoder_inputs( output, expected_total_reqs=0, ) # ============================================================================== # EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests end # ============================================================================== def test_prepend_skipped_requests_order(): scheduler = create_scheduler(max_num_seqs=1, use_kv_connector=True) requests = create_requests(num_requests=4) for request in requests: scheduler.add_request(request) # 4 requests waiting, capture their order expected_waiting_reqs = list(scheduler.waiting) # simulate first 2 waiting requests are waiting for remote KVs for req in expected_waiting_reqs[:2]: req.status = RequestStatus.WAITING_FOR_REMOTE_KVS # schedule step # expect the first 2 waiting to be skipped, the third running, # and the fourth waiting scheduler.schedule() # pop the third request which is expected to be running expected_waiting_reqs.pop(2) # verify waiting order is preserved assert list(scheduler.waiting) == expected_waiting_reqs def test_abort_request_waiting_for_remote_kvs(): scheduler = create_scheduler(use_kv_connector=True) # add a single request request = create_requests(num_requests=1)[0] scheduler.add_request(request) # set request to waiting for remote KVs, and abort it request.status = RequestStatus.WAITING_FOR_REMOTE_KVS scheduler.finish_requests((request.request_id,), RequestStatus.FINISHED_ABORTED) assert request.status == RequestStatus.FINISHED_ABORTED # verify request is not deleted assert request.request_id in scheduler.requests # finish recving request scheduler_output = scheduler.schedule() model_runner_output = ModelRunnerOutput( req_ids=[], req_id_to_index={}, kv_connector_output=KVConnectorOutput(finished_recving={request.request_id}), ) scheduler.update_from_output(scheduler_output, model_runner_output) # assert request is deleted assert request.request_id not in scheduler.requests assert not scheduler.finished_recving_kv_req_ids def test_abort_request_finished_recving(): scheduler = create_scheduler(use_kv_connector=True) # add a single request request = create_requests(num_requests=1)[0] scheduler.add_request(request) # set request to waiting for remote KVs, finished but not yet updated request.status = RequestStatus.WAITING_FOR_REMOTE_KVS scheduler.finished_recving_kv_req_ids.add(request.request_id) # abort request scheduler.finish_requests((request.request_id,), RequestStatus.FINISHED_ABORTED) assert request.status == RequestStatus.FINISHED_ABORTED # verify request is deleted assert request.request_id not in scheduler.requests assert not scheduler.finished_recving_kv_req_ids # ============================================================================== # Variable-length encoder cross-attention block allocation tests # ============================================================================== def _create_encoder_decoder_scheduler( block_size: int = 16, num_blocks: int = 10000, max_num_batched_tokens: int = 8192, max_num_seqs: int = 16, ) -> Scheduler: """Create a scheduler configured for encoder-decoder cross-attention block allocation testing. Constructs a scheduler with both FullAttentionSpec (self-attention) and CrossAttentionSpec (cross-attention) KV cache groups, then patches it to behave as an encoder-decoder model. """ from vllm.v1.core.encoder_cache_manager import EncoderDecoderCacheManager from vllm.v1.kv_cache_interface import CrossAttentionSpec model_config = ModelConfig( model="facebook/opt-125m", trust_remote_code=True, dtype="float16", seed=42, ) scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_num_batched_tokens, # is_encoder_decoder disables chunked prefill and prefix caching is_encoder_decoder=True, ) cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, swap_space=0, cache_dtype="auto", enable_prefix_caching=False, ) cache_config.num_gpu_blocks = num_blocks vllm_config = VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, ) # KV cache config with both self-attention and cross-attention groups, # mirroring an encoder-decoder model like Whisper. kv_cache_config = KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( ["self_attn_layer"], FullAttentionSpec( block_size=block_size, num_kv_heads=1, head_size=1, dtype=torch.float32, ), ), KVCacheGroupSpec( ["cross_attn_layer"], CrossAttentionSpec( block_size=block_size, num_kv_heads=1, head_size=1, dtype=torch.float32, ), ), ], ) # Construct the scheduler. Since opt-125m is not truly encoder-decoder, # the __init__ won't set up encoder-decoder internals. We patch them # after construction. scheduler = Scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, block_size=block_size, structured_output_manager=StructuredOutputManager(vllm_config), ) # Patch to enable encoder-decoder behavior in the scheduling loop. scheduler.is_encoder_decoder = True scheduler.max_num_encoder_input_tokens = max_num_batched_tokens scheduler.encoder_cache_manager = EncoderDecoderCacheManager( cache_size=max_num_batched_tokens ) return scheduler def _get_num_cross_attn_blocks(scheduler: Scheduler, request_id: str) -> int: """Get the number of cross-attention blocks allocated for a request.""" from vllm.v1.core.single_type_kv_cache_manager import CrossAttentionManager coordinator = scheduler.kv_cache_manager.coordinator for manager in coordinator.single_type_managers: if isinstance(manager, CrossAttentionManager): blocks = manager.req_to_blocks.get(request_id, []) return len(blocks) raise AssertionError("No CrossAttentionManager found in coordinator") def test_variable_length_cross_attn_block_allocation(): """Test that cross-attention blocks are allocated per-request based on actual encoder input length, not a fixed maximum. Fixed max-encoder-length allocation would assign `ceil(max_encoder_tokens / block_size)` blocks to every request whereas with dynamic allocation, exactly `ceil(actual_encoder_tokens / block_size)` blocks are assigned to each request. """ block_size = 16 scheduler = _create_encoder_decoder_scheduler(block_size=block_size) # Create requests with distinctly different encoder input lengths, # simulating variable-length audio inputs to a model like Whisper. encoder_lengths = [500, 1000, 200] num_prompt_tokens = 100 # Decoder prompt tokens requests = [] for i, enc_len in enumerate(encoder_lengths): req = create_requests( num_requests=1, num_tokens=num_prompt_tokens, mm_hashes_list=[[f"enc_hash_{i}"]], mm_positions=[[PlaceholderRange(offset=0, length=enc_len)]], req_ids=[f"req_{i}"], )[0] requests.append(req) # Add and schedule all requests. for req in requests: scheduler.add_request(req) output = scheduler.schedule() # All requests should be scheduled. assert len(output.scheduled_new_reqs) == len(requests) # Verify cross-attention blocks per request match the actual encoder length. from math import ceil for req, enc_len in zip(requests, encoder_lengths): expected_blocks = ceil(enc_len / block_size) actual_blocks = _get_num_cross_attn_blocks(scheduler, req.request_id) assert actual_blocks == expected_blocks, ( f"Request {req.request_id} with {enc_len} encoder tokens: " f"expected {expected_blocks} cross-attn blocks, " f"got {actual_blocks}" ) # Verify that different encoder lengths produce different block counts, # confirming variable-length (not fixed-max) allocation. block_counts = [ _get_num_cross_attn_blocks(scheduler, req.request_id) for req in requests ] assert len(set(block_counts)) > 1, ( "All requests have the same number of cross-attn blocks, " "suggesting static max-based allocation instead of per-request" ) def test_cross_attn_blocks_not_over_allocated(): """Test that cross-attention blocks are not over-allocated compared to what each request actually needs.""" from math import ceil block_size = 16 max_encoder_tokens = 1500 # e.g., Whisper's max mel-spectrogram length scheduler = _create_encoder_decoder_scheduler(block_size=block_size) # Request with a small encoder input (much less than the max). small_enc_len = 200 request = create_requests( num_requests=1, num_tokens=100, mm_hashes_list=[["enc_small"]], mm_positions=[[PlaceholderRange(offset=0, length=small_enc_len)]], req_ids=["req_small"], )[0] scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id) expected_blocks = ceil(small_enc_len / block_size) max_blocks = ceil(max_encoder_tokens / block_size) # Blocks should match the actual encoder length. assert actual_blocks == expected_blocks, ( f"Expected {expected_blocks} blocks for {small_enc_len} encoder tokens, " f"got {actual_blocks}" ) # Blocks should be strictly less than what max-based allocation would give. assert actual_blocks < max_blocks, ( f"Cross-attn blocks ({actual_blocks}) should be less than max " f"({max_blocks}), indicating no over-allocation" ) def test_cross_attn_blocks_not_under_allocated(): """Test that cross-attention blocks are sufficient for each request's actual encoder input length. Every encoder token must have a slot. Tests various edge cases including exact block boundaries, off-by-one, and the minimum/maximum encoder input sizes. """ from math import ceil block_size = 16 # Test various encoder lengths including edge cases around block boundaries. test_cases = [ 1, # Minimum: single encoder token block_size - 1, # Just under one full block block_size, # Exactly one full block block_size + 1, # Just over one block (needs 2 blocks) block_size * 10, # Exact multiple of block size block_size * 10 + 1, # One over exact multiple 1500, # Whisper's typical max ] for enc_len in test_cases: scheduler = _create_encoder_decoder_scheduler(block_size=block_size) request = create_requests( num_requests=1, num_tokens=100, mm_hashes_list=[[f"enc_{enc_len}"]], mm_positions=[[PlaceholderRange(offset=0, length=enc_len)]], req_ids=[f"req_{enc_len}"], )[0] scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id) expected_blocks = ceil(enc_len / block_size) # Number of blocks must be exactly ceil(enc_len / block_size). assert actual_blocks == expected_blocks, ( f"Encoder length {enc_len}: expected {expected_blocks} blocks, " f"got {actual_blocks}" ) # Total available slots must be >= encoder tokens (no under-allocation). total_slots = actual_blocks * block_size assert total_slots >= enc_len, ( f"Encoder length {enc_len}: total slots {total_slots} < " f"needed {enc_len} (under-allocation)" ) def test_cross_attn_zero_blocks_without_encoder_inputs(): """Test that requests without encoder inputs get zero cross-attention blocks, even when the scheduler is configured for encoder-decoder.""" block_size = 16 scheduler = _create_encoder_decoder_scheduler(block_size=block_size) # Create a text-only request (no mm_features). request = create_requests( num_requests=1, num_tokens=100, req_ids=["req_text_only"], )[0] # Text-only request has no encoder inputs. assert not request.has_encoder_inputs scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 # No cross-attention blocks should be allocated. actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id) assert actual_blocks == 0, ( f"Text-only request should have 0 cross-attn blocks, got {actual_blocks}" ) def test_eagle3_mm_encoder_cache_with_shift(): """Test EAGLE3 encoder scheduling accounts for shift_computed_tokens. Regression test for issue #32469: When EAGLE3 is enabled with disable_chunked_mm_input=True, ensure encoder inputs are scheduled when tokens overlap the MM range, properly accounting for shift_computed_tokens in the boundary calculation. Without the fix, the scheduler would fail to schedule encoder inputs at the boundary, causing "Encoder cache miss" errors. """ scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, disable_chunked_mm_input=True, max_model_len=2048, num_speculative_tokens=4, # This enables EAGLE with shift=1 ) mm_start_pos = 100 mm_length = 576 mm_positions = [ [PlaceholderRange(offset=mm_start_pos, length=mm_length)], ] requests = create_requests( num_requests=1, num_tokens=mm_start_pos + mm_length + 100, mm_positions=mm_positions, ) # Start with some tokens already computed to simulate decoding request = requests[0] request.num_computed_tokens = 0 scheduler.add_request(request) output = scheduler.schedule() assert output is not None shift_computed_tokens = 1 req_id = request.request_id assert req_id in output.num_scheduled_tokens num_scheduled = output.num_scheduled_tokens[req_id] mm_feature = request.mm_features[0] start_pos = mm_feature.mm_position.offset tokens_end = request.num_computed_tokens + num_scheduled scheduled_end_with_shift = tokens_end + shift_computed_tokens # Assert that we scheduled into the MM range (test setup verification) assert scheduled_end_with_shift > start_pos, ( f"Test setup error: expected to schedule into MM range. " f"scheduled_end_with_shift={scheduled_end_with_shift}, " f"start_pos={start_pos}" ) # The key assertion: when scheduled tokens overlap MM range # (accounting for EAGLE's shift), encoder MUST be scheduled. # Without the fix, this would fail at the boundary case. assert req_id in output.scheduled_encoder_inputs, ( f"Encoder input missing: scheduled {num_scheduled} tokens " f"(computed={request.num_computed_tokens}, end={tokens_end}, " f"shifted_end={scheduled_end_with_shift}) overlapping MM at " f"{start_pos}. The fix must schedule encoder inputs." )