# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from unittest.mock import Mock

import pytest
import torch

from vllm.config import (
    CacheConfig,
    ECTransferConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
    MultiModalKwargsItem,
    PlaceholderRange,
)
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.utils.hashing import sha256
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
)
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv

pytestmark = pytest.mark.cpu_test


def test_add_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)

    for i, request in enumerate(requests):
        scheduler.add_request(request)
        assert request.request_id in scheduler.requests
        assert len(scheduler.waiting) == i + 1


def test_finish_request():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
        assert request.request_id not in scheduler.requests
        assert len(scheduler.waiting) == 9 - i


def test_get_num_unfinished_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED)
        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


@pytest.mark.parametrize(
    "enable_prefix_caching, prompt_logprobs",
    [
        (False, None),
        (True, 5),
    ],
)
def test_schedule(enable_prefix_caching: bool, prompt_logprobs: int | None):
    """Test scheduling.

    Two cases: default APC / no prompt logprobs; APC=True + prompt logprobs.
    """
    scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
    requests = create_requests(num_requests=10, prompt_logprobs=prompt_logprobs)
    for request in requests:
        scheduler.add_request(request)

    # Test initial scheduling
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert output.scheduled_cached_reqs.num_reqs == 0
    assert len(output.finished_req_ids) == 0
    # Verify all requests are scheduled.
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)

    # Verify requests moved from waiting to running
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == len(requests)
    for i, request in enumerate(requests):
        assert scheduler.running[i] == request


def test_schedule_multimodal_requests():
    scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf")
    mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
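    # Each request carries one 100-token image placeholder, at offsets 0 through 9.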
    requests = create_requests(
        num_requests=10,
        num_tokens=200,
        mm_positions=mm_positions,
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert output.scheduled_cached_reqs.num_reqs == 0
    assert len(output.finished_req_ids) == 0
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
    assert len(output.scheduled_encoder_inputs) == 10
    for req_id, encoder_input in output.scheduled_encoder_inputs.items():
        assert len(encoder_input) == 1


def test_schedule_partial_requests():
    """Test scheduling behavior with partial requests.

    This test verifies that:
    1. The scheduler can handle multiple partial requests in a single step when
       constrained by encoder budget.
    2. A request in RUNNING state may be unscheduled in subsequent steps if
       there is insufficient encoder budget.
    """
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        max_num_batched_tokens=1024,
    )
    mm_positions = [[PlaceholderRange(offset=100, length=600)] for _ in range(3)]
    requests = create_requests(
        num_requests=3,
        num_tokens=800,
        mm_positions=mm_positions,
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 3
    assert output.scheduled_cached_reqs.num_reqs == 0
    assert len(output.finished_req_ids) == 0

    assert scheduler.max_num_encoder_input_tokens == 1024
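    # The encoder budget equals max_num_batched_tokens (1024). The first
    # request's 600-token image fits, but a second image would exceed the
    # budget, so requests 1 and 2 only get the 100 text tokens that precede
    # their image.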
    # The first request is scheduled fully.
    assert output.num_scheduled_tokens[requests[0].request_id] == 800
    # The second request is scheduled partially.
    # The <img> tokens are not scheduled because of the encoder budget.
    assert output.num_scheduled_tokens[requests[1].request_id] == 100
    # The third request is also scheduled partially.
    # The <img> tokens are not scheduled because of the encoder budget.
    assert output.num_scheduled_tokens[requests[2].request_id] == 100
    req_to_index = {request.request_id: i for i, request in enumerate(requests)}
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        # Only the first request has a sampled token id because
        # the remaining requests are still being prefilled.
        sampled_token_ids=[[0], [], []],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_runner_output)

    # Schedule the next step.
    # Only the first and second requests are scheduled.
    # The third request is in the RUNNING state but not scheduled in this step
    # because of the encoder budget.
    output = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(output.scheduled_new_reqs) == 0
    assert output.scheduled_cached_reqs.num_reqs == 2
    assert len(output.finished_req_ids) == 0
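    # Request 0 is now decoding (1 token); request 1 finishes its remaining
    # 800 - 100 = 700 prompt tokens; request 2 stays unscheduled because its
    # image still does not fit into the encoder budget this step.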
    assert output.num_scheduled_tokens[requests[0].request_id] == 1
    assert output.num_scheduled_tokens[requests[1].request_id] == 700
    assert requests[2].request_id not in output.num_scheduled_tokens


def test_no_mm_input_chunking():
    # Disable multimodal input chunking.
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        max_num_batched_tokens=1024,
        disable_chunked_mm_input=True,
        max_model_len=2048,
    )
    mm_positions = [[PlaceholderRange(offset=400, length=800)]]
    requests = create_requests(
        num_requests=1, num_tokens=1200, mm_positions=mm_positions
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_cached_reqs.num_reqs == 0
    assert len(output.finished_req_ids) == 0
    # We want to only see the 400 text tokens at the start scheduled
    assert output.num_scheduled_tokens[requests[0].request_id] == 400

    req_to_index = {request.request_id: i for i, request in enumerate(requests)}
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[] for _ in range(len(requests))],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_runner_output)

    output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(output.scheduled_new_reqs) == 0
    assert output.scheduled_cached_reqs.num_reqs == 1
    assert len(output.finished_req_ids) == 0
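    # The remaining 800 tokens are the image itself; with chunked mm input
    # disabled it must be scheduled in one piece.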
    assert output.num_scheduled_tokens[requests[0].request_id] == 800

    # Test that we fail if we disable chunked mm input and use too small
    # of a max_num_batched_tokens for the mm input.
    with pytest.raises(ValueError):
        _ = create_scheduler(
            model="llava-hf/llava-1.5-7b-hf",
            max_num_batched_tokens=100,
            disable_chunked_mm_input=True,
        )


@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
    """Test scheduling behavior with concurrent partial requests.

    This test verifies that multiple long prefill requests can be in the
    RUNNING state at the same time and be scheduled together.
    """
    scheduler = create_scheduler(
        model="facebook/opt-125m",
        max_num_batched_tokens=1024,
        long_prefill_token_threshold=400,
        enable_prefix_caching=enable_prefix_caching,
    )
    requests = create_requests(
        num_requests=3,
        num_tokens=800,
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 3
    assert output.scheduled_cached_reqs.num_reqs == 0
    assert len(output.finished_req_ids) == 0

    # The first request is scheduled partially - 400.
    assert output.num_scheduled_tokens[requests[0].request_id] == 400
    # The second request is scheduled partially - 400.
    assert output.num_scheduled_tokens[requests[1].request_id] == 400
    # The third request is also scheduled partially - 1024 - 400 - 400 = 224.
    assert output.num_scheduled_tokens[requests[2].request_id] == 224
    req_to_index = {request.request_id: i for i, request in enumerate(requests)}
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[] for _ in range(len(requests))],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_runner_output)

    # Schedule the next step. All three requests are running.
    # Processed the remaining prefills of the first and second requests.
    output1 = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(output1.scheduled_new_reqs) == 0
    assert output1.scheduled_cached_reqs.num_reqs == 3
    assert len(output1.finished_req_ids) == 0
    assert output1.num_scheduled_tokens[requests[0].request_id] == 400
    assert output1.num_scheduled_tokens[requests[1].request_id] == 400
    assert output1.num_scheduled_tokens[requests[2].request_id] == 224

    # Schedule the third step. All three requests are running.
    # First and second requests are in the decode stage.
    # All the remaining tokens in the third request are processed.
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output1, model_runner_output)
    output2 = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(output2.scheduled_new_reqs) == 0
    assert output2.scheduled_cached_reqs.num_reqs == 3
    assert len(output2.finished_req_ids) == 0
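    # Requests 0 and 1 have finished prefill and now decode one token each.
    # Request 2 prefilled 224 tokens in each of the first two steps, so
    # 800 - 224 - 224 = 352 prompt tokens remain.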
    assert output2.num_scheduled_tokens[requests[0].request_id] == 1
    assert output2.num_scheduled_tokens[requests[1].request_id] == 1
    assert output2.num_scheduled_tokens[requests[2].request_id] == 800 - 224 - 224


def test_stop_via_update_from_output():
    """Test stopping behavior through update_from_output"""
    scheduler = create_scheduler(num_speculative_tokens=1)

    # Test case 1: Stop on EOS token
    requests = create_requests(num_requests=2, max_tokens=10)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        req.status = RequestStatus.RUNNING

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={requests[0].request_id: 1, requests[1].request_id: 2},
        total_num_scheduled_tokens=3,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={
            requests[0].request_id: [],
            requests[1].request_id: [10],
        },
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )
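    # Request 0 decodes a single token; request 1 is scheduled 2 tokens
    # because it also verifies the draft token [10].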

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
        sampled_token_ids=[
            [EOS_TOKEN_ID],
            [10, 11],
        ],  # First request hits EOS, second continues
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped, second continues
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
    assert list(requests[1].output_token_ids) == [10, 11]

    # Test case 2: Stop on custom stop token
    scheduler = create_scheduler(num_speculative_tokens=2)
    requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43])
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        req.status = RequestStatus.RUNNING

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 2},
        total_num_scheduled_tokens=5,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={
            requests[0].request_id: [10, 42],
            requests[1].request_id: [13],
        },
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
        sampled_token_ids=[[10, 42, 12], [13, 14]],  # First request hits stop token
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
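    # Both draft tokens [10, 42] of request 0 are accepted; 42 is a stop
    # token, so the output is truncated there and the trailing sample (12)
    # is discarded.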

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped on custom token
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].stop_reason == 42
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 42]
    assert list(requests[1].output_token_ids) == [13, 14]

    # Test case 3: Stop on max tokens
    scheduler = create_scheduler(num_speculative_tokens=2)
    requests = create_requests(num_requests=2, max_tokens=2)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        req.status = RequestStatus.RUNNING

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 1},
        total_num_scheduled_tokens=4,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={
            requests[0].request_id: [10, 11],
            requests[1].request_id: [],
        },
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
        sampled_token_ids=[[10, 11, 12], [13]],  # First request exceeds max_tokens
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped due to length
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 11]  # Truncated to max_tokens
    assert list(requests[1].output_token_ids) == [13]

    # Test case 4: Ignore EOS flag
    scheduler = create_scheduler(num_speculative_tokens=2)
    requests = create_requests(num_requests=1, max_tokens=10)
    requests[0].sampling_params.ignore_eos = True
    requests[0].num_computed_tokens = requests[0].num_tokens
    scheduler.requests[requests[0].request_id] = requests[0]
    scheduler.running.append(requests[0])

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={requests[0].request_id: 3},
        total_num_scheduled_tokens=3,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )

    model_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify request continues past EOS
    assert len(scheduler.running) == 1
    assert not requests[0].is_finished()
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]


def test_check_stop_min_tokens():
    """Test that requests don't stop when min_tokens requirement isn't met."""
    from vllm.v1.core.sched.utils import check_stop

    # Test case 1: num_output_tokens < min_tokens
    # Should return False (don't stop)
    sampling_params = SamplingParams(
        ignore_eos=False,
        max_tokens=20,
        min_tokens=5,
    )
    request = Request(
        request_id="0",
        prompt_token_ids=[0, 1, 2],
        sampling_params=sampling_params,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
    )
    # Simulate having generated 3 output tokens (less than min_tokens=5)
    request.append_output_token_ids([10, 11, EOS_TOKEN_ID])  # EOS token present

    result = check_stop(request, max_model_len=100)
    assert result is False, "Should not stop when num_output_tokens<min_tokens"

    # Test case 2: num_output_tokens >= min_tokens
    # Should follow normal stopping logic (stop on EOS)
    request.append_output_token_ids(
        [
            10,
            11,
            12,
            13,
            14,
            EOS_TOKEN_ID,
        ]
    )  # 6 tokens > min_tokens

    result = check_stop(request, max_model_len=100)
    assert result is True, "Should stop on EOS when min_tokens met"
    assert request.status == RequestStatus.FINISHED_STOPPED

    # Test case 3: min_tokens = 0, should follow normal stopping logic
    sampling_params_no_min = SamplingParams(
        ignore_eos=False,
        max_tokens=20,
        min_tokens=0,
    )
    request_no_min = Request(
        request_id="1",
        prompt_token_ids=[0, 1, 2],
        sampling_params=sampling_params_no_min,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
    )
    request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])

    result = check_stop(request_no_min, max_model_len=100)
    assert result is True, "Should stop on EOS when min_tokens=0"
    assert request_no_min.status == RequestStatus.FINISHED_STOPPED

    # Test case 4: min_tokens > 0 with stop token (not EOS)
    sampling_params_stop = SamplingParams(
        ignore_eos=False,
        max_tokens=20,
        min_tokens=5,
        stop_token_ids=[42],
    )
    request_stop = Request(
        request_id="2",
        prompt_token_ids=[0, 1, 2],
        sampling_params=sampling_params_stop,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
    )
    # Only 3 output tokens, less than min_tokens=5, but has stop token
    request_stop.append_output_token_ids([10, 11, 42])
    result = check_stop(request_stop, max_model_len=100)
    assert result is False, "Should not stop when num_output_tokens<min_tokens"

    # Test case 5: min_tokens met, should stop on stop token
    request_stop.append_output_token_ids(
        [10, 11, 12, 13, 14, 42]
    )  # 6 tokens >= min_tokens=5

    result = check_stop(request_stop, max_model_len=100)
    assert result is True, "Should stop on stop token when min_tokens met"
    assert request_stop.status == RequestStatus.FINISHED_STOPPED
    assert request_stop.stop_reason == 42


@pytest.mark.parametrize(
    "enable_prefix_caching, prompt_logprobs",
    [
        (False, None),
        (True, 5),
    ],
)
def test_schedule_concurrent_batches(
    enable_prefix_caching: bool, prompt_logprobs: int | None
):
    scheduler = create_scheduler(
        max_num_batched_tokens=1024,
        max_num_seqs=2,
        enable_prefix_caching=enable_prefix_caching,
    )
    requests = create_requests(
        num_requests=2,
        num_tokens=512,
        prompt_logprobs=prompt_logprobs,
    )

    # Schedule the first request.
    scheduler.add_request(requests[0])
    scheduler_output0 = scheduler.schedule()
    assert len(scheduler_output0.scheduled_new_reqs) == 1
    assert scheduler_output0.num_scheduled_tokens[requests[0].request_id] == 512

    # The first request is still running, so only schedule the second request.
    scheduler.add_request(requests[1])
    scheduler_output1 = scheduler.schedule()
    assert len(scheduler_output1.scheduled_new_reqs) == 1
    assert scheduler_output1.num_scheduled_tokens[requests[1].request_id] == 512

    # Model output of the first request.
    model_runner_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[0]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(scheduler_output0, model_runner_output)

    # Schedule the next step.
    # The first request can be scheduled again while the second
    # request is still running.
    scheduler_output2 = scheduler.schedule()
    assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1

    # Model output of the second request.
    model_runner_output = ModelRunnerOutput(
        req_ids=[requests[1].request_id],
        req_id_to_index={requests[1].request_id: 0},
        sampled_token_ids=[[0]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(scheduler_output1, model_runner_output)


@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
def test_schedule_order(enable_chunked_prefill: bool):
    scheduler = create_scheduler(
        max_num_batched_tokens=1024,
        max_num_seqs=3,
        enable_chunked_prefill=enable_chunked_prefill,
    )

    # long requests
    requests = create_requests(num_requests=2, num_tokens=800)
    # short requests
    requests += create_requests(num_requests=2, num_tokens=10)

    for request in requests:
        scheduler.add_request(request)

    scheduler_output1 = scheduler.schedule()

    if enable_chunked_prefill:
        # With chunked prefill enabled, the long requests are chunked.
        assert len(scheduler_output1.scheduled_new_reqs) == 2
    else:
        # With chunked prefill disabled, the scheduler must not skip the long
        # requests and schedule the later short requests ahead of them, even
        # though there is still token budget remaining.
        assert len(scheduler_output1.scheduled_new_reqs) == 1


def test_preempt_during_execution():
    # NOTE(woosuk): The actual number of available blocks is 10 instead of 11
    # because block 0 is reserved as the null block.
    scheduler = create_scheduler(
        max_num_batched_tokens=100,
        block_size=16,
        num_blocks=11,
        enable_prefix_caching=False,
    )
    requests = create_requests(num_requests=2, num_tokens=80, block_size=16)

    # Schedule the first request.
    scheduler.add_request(requests[0])
    scheduler_output0 = scheduler.schedule()
    assert len(scheduler_output0.num_scheduled_tokens) == 1
    assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5
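    # 80 prompt tokens / block_size 16 = 5 blocks per request.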

    # Schedule the second request while the first request is still running.
    # This scenario can occur in certain cases, when max_concurrent_batches > 1
    # (e.g., when pipeline parallelism is used).
    scheduler.add_request(requests[1])
    scheduler_output1 = scheduler.schedule()
    assert len(scheduler_output1.num_scheduled_tokens) == 1
    assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5

    # Get the output of the first request.
    model_runner_output0 = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[0]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(scheduler_output0, model_runner_output0)

    # Schedule the first request again. This will cause the preemption
    # of the second request because the KV cache is full.
    _ = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert scheduler.running[0] == requests[0]
    assert requests[1].status == RequestStatus.PREEMPTED

    model_runner_output1 = ModelRunnerOutput(
        req_ids=[requests[1].request_id],
        req_id_to_index={requests[1].request_id: 0},
        sampled_token_ids=[[42]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(scheduler_output1, model_runner_output1)

    # The second request (that is preempted) should be updated with the
    # sampled token id.
    assert len(requests[1].output_token_ids) == 1
    assert requests[1].output_token_ids[0] == 42


def test_scheduler_reset_prefix_cache():
    scheduler = create_scheduler(enable_prefix_caching=True)
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    # Initial scheduling, requests should be at the running state now
    _ = scheduler.schedule()

    # Verify requests moved from waiting to running
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == len(requests)
    for i, request in enumerate(requests):
        assert scheduler.running[i] == request

    # Resetting the prefix cache should fail since there are still running
    # requests and they are holding KV cache.
    assert not scheduler.reset_prefix_cache()

    # Reset prefix cache with reset_running_requests=True. All running requests
    # should be pushed back to the waiting queue and their KV cache freed.
    assert scheduler.reset_prefix_cache(reset_running_requests=True)

    # Verify requests moved from running to waiting
    assert len(scheduler.waiting) == len(requests)
    assert len(scheduler.running) == 0

    for i, request in enumerate(requests):
        assert scheduler.waiting[i] == request


# Note - these test cases mirror some of those in test_rejection_sampler.py
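# expected = (num_drafts, num_draft_tokens, num_accepted_tokens,
#             num_accepted_tokens_per_pos)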
@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])),  # perfect match
        ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])),  # early mismatch
        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])),  # multiple sequences
        ([[1]], [[1, 2]], (1, 1, 1, [1])),  # single token sequence
        ([[]], [[5]], (0, 0, 0, [0])),  # empty sequence
        (
            [[1, 2, 3], [4, 5, 6]],
            [[1, 2, 7], [4, 8]],
            (2, 6, 3, [2, 1, 0]),
        ),  # multiple mismatches
    ],
)
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
    """Test scheduling behavior with speculative decoding.

    This test verifies that:
    1. Speculated tokens get scheduled correctly
    2. Spec decoding stats properly count number of draft and accepted tokens
    """
    num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    # Schedule a decode, which will also draft speculative tokens
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert output.total_num_scheduled_tokens == len(requests)
    for i in range(len(requests)):
        req_id = requests[i].request_id
        assert output.num_scheduled_tokens[req_id] == 1
        assert req_id not in output.scheduled_spec_decode_tokens

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0] for _ in range(len(requests))],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    engine_core_outputs = scheduler.update_from_output(output, model_runner_output)
    draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
    scheduler.update_draft_token_ids(draft_token_ids)

    for i in range(len(requests)):
        running_req = scheduler.running[i]
        # The prompt token
        assert running_req.num_computed_tokens == 1
        # The prompt token and the sampled token
        assert running_req.num_tokens == 2
        # The prompt token, the sampled token, and the speculated tokens
        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])

    # No draft or accepted tokens counted yet
    assert not engine_core_outputs or (
        engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None
    )

    # Schedule the speculated tokens for validation
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 0
    # The sampled token and speculated tokens
    assert output.total_num_scheduled_tokens == len(requests) + sum(
        len(ids) for ids in spec_tokens
    )
    for i in range(len(requests)):
        req_id = requests[i].request_id
        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
        if spec_tokens[i]:
            assert len(output.scheduled_spec_decode_tokens[req_id]) == len(
                spec_tokens[i]
            )
        else:
            assert req_id not in output.scheduled_spec_decode_tokens

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=output_tokens,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    engine_core_outputs = scheduler.update_from_output(output, model_runner_output)

    scheduler_stats = (
        engine_core_outputs[0].scheduler_stats if engine_core_outputs else None
    )
    if expected[0] == 0:
        assert scheduler_stats is not None
        assert scheduler_stats.spec_decoding_stats is None
    else:
        assert scheduler_stats is not None
        assert scheduler_stats.spec_decoding_stats is not None
        stats = scheduler_stats.spec_decoding_stats
        assert stats.num_drafts == expected[0]
        assert stats.num_draft_tokens == expected[1]
        assert stats.num_accepted_tokens == expected[2]
        assert stats.num_accepted_tokens_per_pos == expected[3]


def _assert_right_scheduler_output(
    output: SchedulerOutput,
    num_requests: int,
    expected_num_scheduled_tokens: int,
):
    """Check if SchedulerOutput is correct after remote KV cache hit."""

    # We should inject the kv_connector_metadata.
    assert len(output.kv_connector_metadata.requests) == num_requests

    # Only num_tokens - matched_num_new_tokens should be scheduled.
    for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
        assert num_scheduled_tokens == expected_num_scheduled_tokens


def _assert_right_kv_cache_manager(
    scheduler: Scheduler,
    requests: list[Request],
    num_tokens: int,
    block_size: int,
    num_requests: int,
    num_total_blocks: int,
):
    """Check whether KVCacheManager is correct after allocate."""

    # Make sure the request stats are right.
    EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
    for req in requests:
        blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
            0
        ].req_to_blocks[req.request_id]
        hashes = req.block_hashes
        assert (
            scheduler.kv_cache_manager.coordinator.single_type_managers[
                0
            ].num_cached_block[req.request_id]
            == EXPECTED_TOTAL_BLOCKS
        )
        assert len(blocks) == EXPECTED_TOTAL_BLOCKS
        assert len(hashes) == EXPECTED_TOTAL_BLOCKS

    # Make sure we actually touched all the blocks.
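    # (Assumes num_tokens is a multiple of block_size, as in the tests that
    # call this helper.)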
    BLOCKS_PER_REQ = num_tokens / block_size
    assert (
        scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
        == num_total_blocks - num_requests * BLOCKS_PER_REQ
    )


def _step_until_done(
    scheduler: Scheduler,
    output: SchedulerOutput,
    model_runner_output: ModelRunnerOutput,
):
    """Loop over schedule(), update_from_output() until finished."""

    all_finished = False
    _ = scheduler.update_from_output(output, model_runner_output)
    while not all_finished:
        # Schedule + a few iterations until stopping.
        output = scheduler.schedule()
        assert len(scheduler.running)
        for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
            # We should be in the decode phase now.
            assert num_scheduled_tokens == 1
        if scheduler.connector is not None:
            assert len(output.kv_connector_metadata.requests) == 0
        if scheduler.ec_connector is not None:
            assert len(output.ec_connector_metadata.mm_datas) == 0
        ecos = scheduler.update_from_output(output, model_runner_output)[0]
        all_done = True
        for eco in ecos.outputs:
            if eco.finish_reason is None:
                all_done = False
        all_finished = all_done


def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
    """Cycle requests through a KV transfer cycle."""

    # Requests should first transition to WAITING_FOR_REMOTE_KVS
    output = scheduler.schedule()
    assert len(scheduler.waiting) == len(req_ids)
    assert len(scheduler.running) == 0
    assert len(output.scheduled_new_reqs) == 0
    for req in scheduler.requests.values():
        assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS

    # No model execution yet
    EMPTY_OUTPUT = ModelRunnerOutput(
        req_ids=[],
        req_id_to_index={},
        sampled_token_ids=[],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, EMPTY_OUTPUT)

    # Simulate KV transfer completion using KVConnectorOutput.finished_recving
    output = scheduler.schedule()
    assert len(scheduler.waiting) == len(req_ids)
    assert len(scheduler.running) == 0

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=[],
        req_id_to_index={},
        sampled_token_ids=[],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
        kv_connector_output=KVConnectorOutput(finished_recving=req_ids),
    )
    scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    for req_id in req_ids:
        assert req_id in scheduler.finished_recving_kv_req_ids


@pytest.mark.parametrize("is_async", [False, True])
def test_kv_connector_basic(is_async: bool):
    """
    Test whether Scheduler with KVConnector schedules tokens, allocates
    memory, and cleans up requests as expected under normal operation.
    """

    # Setup Scheduler.
    BLOCK_SIZE = 16
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=mock_kv(
            matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
        ),
        block_size=BLOCK_SIZE,
    )
    NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()

    ######################################################
    # FIRST SET OF REQUESTS - External Hit Only
    NUM_REQUESTS = 2
    NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2
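    # Each prompt is 2 * NUM_MATCHED_NEW_TOKENS = 64 tokens; the mock
    # connector reports the first 32 of them as already available remotely.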
    MAX_TOKENS = 3
    requests = create_requests(
        num_requests=NUM_REQUESTS,
        num_tokens=NUM_TOKENS,
        max_tokens=MAX_TOKENS,
        block_size=BLOCK_SIZE,
    )
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    if is_async:
        _step_until_kv_transfer_finished(scheduler, req_ids)

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    # Ensure SchedulerOutput is correct.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output=output,
        num_requests=NUM_REQUESTS,
        # Just the incremental tokens should be scheduled.
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )

    # Ensure KVCacheManager is correct.
    _assert_right_kv_cache_manager(
        scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS
    )

    # Continue Generation until done.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    _ = scheduler.schedule()
    # Confirm we clean up the memory properly.
    assert (
        scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
    )

    ######################################################
    # SECOND SET OF REQUESTS - Local And External Hit
    NUM_TOKENS_PREFIX = NUM_TOKENS
    # We will get a local prefix cache hit for the first
    # NUM_TOKENS_PREFIX tokens since they are used above.
    NUM_TOKENS = NUM_TOKENS_PREFIX * 2
    requests = create_requests(
        num_requests=NUM_REQUESTS,
        num_tokens=NUM_TOKENS,
        max_tokens=MAX_TOKENS,
        block_size=BLOCK_SIZE,
    )
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    if is_async:
        _step_until_kv_transfer_finished(scheduler, req_ids)

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    # We should get a local cache hit of NUM_TOKENS_PREFIX and
    # a remote KV cache hit of NUM_MATCHED_NEW_TOKENS.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output=output,
        num_requests=NUM_REQUESTS,
        # Just the incremental tokens after local + remote cache hit.
        expected_num_scheduled_tokens=(
            NUM_TOKENS - NUM_TOKENS_PREFIX - NUM_MATCHED_NEW_TOKENS
        ),
    )

    # Ensure KVCacheManager is correct.
    _assert_right_kv_cache_manager(
        scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS
    )

    # Continue Generation until done.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    _ = scheduler.schedule()
    # Confirm we clean up the memory properly.
    assert (
        scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
    )


@pytest.mark.parametrize("is_async", [False, True])
def test_external_prefix_cache_metrics(is_async: bool):
    """
    Verify connector prefix cache metrics are updated
    correctly when the scheduler processes requests with KV connector hits.
    """

    # Setup Scheduler.
    NUM_MATCHED_NEW_TOKENS = 4
    scheduler = create_scheduler(
        enable_prefix_caching=False,
        use_kv_connector=mock_kv(
            matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
        ),
    )

    # --- Prepare simple requests ---
    NUM_REQUESTS = 2
    NUM_TOKENS = 8
    MAX_TOKENS = 2
    requests = create_requests(
        num_requests=NUM_REQUESTS,
        num_tokens=NUM_TOKENS,
        max_tokens=MAX_TOKENS,
    )
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    if is_async:
        _step_until_kv_transfer_finished(scheduler, req_ids)

    # --- Trigger scheduling and simulate model output ---
    output = scheduler.schedule()
    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=[r.request_id for r in requests],
        req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
        sampled_token_ids=[[1000]] * NUM_REQUESTS,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    # Update scheduler stats
    ecos = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # --- Assertions ---
    assert ecos is not None and len(ecos) > 0
    assert ecos[0].scheduler_stats is not None

    external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
    assert external_stats is not None
|
|
|
|
|
|
|
|
|
|
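    # A rough sketch of the expected accounting (assuming the mocked connector
    # reports NUM_MATCHED_NEW_TOKENS externally matched tokens per request):
    # every prompt token is queried against the connector, so `queries` should
    # equal NUM_TOKENS * NUM_REQUESTS, while `hits` only counts the mocked
    # matches.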
    assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS
    assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS
    assert external_stats.requests == NUM_REQUESTS
    assert external_stats.preempted_requests == 0


@pytest.mark.parametrize(
    "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
    """
    Test whether a scheduler with a KVConnector is able to handle being
    unable to allocate (running out of blocks in allocate_slots()).
    """

    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 4
    NUM_BLOCKS = 10
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        # encoder connector should not affect test results
        use_ec_connector=use_ec_connector,
        ec_role=ec_role,
    )

    # Create two requests. The second request will not be able to
    # allocate slots because it will not have enough blocks.
    NUM_REQUESTS = 2
    NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
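    # A rough sanity check of the sizing (assuming one block is reserved as the
    # null block): each request needs NUM_BLOCKS // 2 + 1 = 6 blocks, so two
    # requests would need 12 blocks while only about 9 are usable, which is what
    # forces the second request to wait.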
    MAX_TOKENS = 2
    requests = create_requests(
        num_requests=NUM_REQUESTS,
        num_tokens=NUM_TOKENS,
        max_tokens=MAX_TOKENS,
        block_size=BLOCK_SIZE,
    )
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    # Just one request should be running.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        num_requests=1,
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # All memory should be freed, with one request waiting.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # Just one request should be running.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        num_requests=1,
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # All memory should be freed, with no requests waiting / running.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0


@pytest.mark.parametrize(
    "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
    """
    Test whether a scheduler with a KVConnector is able to handle preemption
    when it runs out of blocks in allocate_slots().
    """

    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 2
    # NOTE: there is 1 null block, so this is 6 blocks.
    NUM_BLOCKS = 7
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        # encoder connector should not affect test results
        use_ec_connector=use_ec_connector,
        ec_role=ec_role,
    )

    # Create two requests.
    # Both can be scheduled at first, but the second request
    # will be preempted and re-scheduled.
    NUM_REQUESTS = 2
    NUM_TOKENS = BLOCK_SIZE * 2 + 1
    MAX_TOKENS = BLOCK_SIZE * 2
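    # Rough sketch of why the preemption happens (assuming one null block, so
    # 6 usable blocks of 2 tokens each): every request starts with a 5-token
    # prompt and decodes up to 4 tokens, so around the 3rd generated token each
    # request spills into a 4th block, and two requests cannot both hold 4 of
    # the 6 usable blocks.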
    requests = create_requests(
        num_requests=NUM_REQUESTS,
        num_tokens=NUM_TOKENS,
        max_tokens=MAX_TOKENS,
        block_size=BLOCK_SIZE,
    )
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    # All can be scheduled - 1st token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # 2 remote kv cache hits.
        num_requests=2,
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )
    assert len(scheduler.running) == 2
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # All can be scheduled - 2nd token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 2
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # This will generate a new block and cause a preemption - 3rd token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # Only 1 can be scheduled - 4th (and last) token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 0
    # All memory should be freed since nothing is running.
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1

    # Restart the preempted request - generate 3rd token.
    # This will have a local and remote cache hit.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # 1 remote kv_cache hit!
        num_requests=1,
        # Only 1 block was preempted and there is a single
        # remote hit. So only a single new token is scheduled.
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # Only 1 can be scheduled - 4th (and last) token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 0
    # All memory should be freed since nothing is running.
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1


def make_output(scheduler: Scheduler):
    """Fabricate a ModelRunnerOutput with one sampled token per running request."""
    return ModelRunnerOutput(
        req_ids=[req.request_id for req in scheduler.running],
        req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)},
        sampled_token_ids=[[1000]] * len(scheduler.running),
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )


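# The leak and priority tests below drive the scheduler with this helper in a
# simple loop, roughly:
#
#     scheduler_output = scheduler.schedule()
#     scheduler.update_from_output(scheduler_output, make_output(scheduler))
#
# i.e. every running request "samples" one dummy token per step until it
# finishes.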


def assert_scheduler_empty(scheduler: Scheduler):
    """Confirm the scheduler is "empty" - i.e. no leaks."""
    # Scheduler Metadata.
    assert len(scheduler.requests) == 0
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.finished_req_ids) == 0

    # EncoderCacheManager.
    assert len(scheduler.encoder_cache_manager.freed) == 0
    assert len(scheduler.encoder_cache_manager.cached) == 0

    # KVCache Manager.
    assert (
        len(
            scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
        )
        == 0
    )
    assert (
        len(
            scheduler.kv_cache_manager.coordinator.single_type_managers[
                0
            ].num_cached_block
        )
        == 0
    )
    num_free_blocks = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
    )
    assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
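    # (The block pool keeps one reserved null block, hence the
    # `num_gpu_blocks - 1` above.)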

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    for block in scheduler.kv_cache_manager.block_pool.blocks:
        assert block.ref_cnt == 0
        # assert block._block_hash is None
    # assert (
    #     len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
    # ) == 0)


def test_memory_leak():
    """Test that we do not have a memory leak."""

    scheduler = create_scheduler(enable_prefix_caching=True)

    NUM_REQUESTS = 5
    NUM_TOKENS = 10
    MAX_TOKENS = 10
    requests = create_requests(
        num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS
    )

    # Add each request.
    for request in requests:
        scheduler.add_request(request)
        scheduler_output = scheduler.schedule()
        model_runner_output = make_output(scheduler)
        scheduler.update_from_output(scheduler_output, model_runner_output)

    # Iterate until done.
    while True:
        scheduler_output = scheduler.schedule()
        if len(scheduler.running) == 0:
            break
        model_runner_output = make_output(scheduler)
        scheduler.update_from_output(scheduler_output, model_runner_output)

    # Confirm no memory leak.
    assert_scheduler_empty(scheduler)


def create_scheduler_with_priority(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: bool = False,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: int | None = None,
    num_speculative_tokens: int | None = None,
    use_ec_connector: bool = False,
    ec_role: str | None = None,
) -> Scheduler:
    """Create scheduler with priority policy enabled.

    Args:
        model: model under test
        max_num_seqs: max sequences to schedule
        max_num_batched_tokens: max num tokens to batch
        enable_prefix_caching: optionally force APC config
            (True/False) or use the default (False)

    Returns:
        {class}`Scheduler` instance with priority scheduling
    """
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        is_encoder_decoder=model_config.is_encoder_decoder,
        policy="priority",  # Enable priority scheduling
    )
    # Cache config, optionally force APC
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    kv_transfer_config = (
        KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        )
        if use_kv_connector
        else None
    )

    speculative_config: SpeculativeConfig | None = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens
        )

    ec_transfer_config = (
        ECTransferConfig(
            ec_connector="ECSharedStorageConnector",
            ec_role=ec_role,
            ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
        )
        if use_ec_connector
        else None
    )

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
        ec_transfer_config=ec_transfer_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
            )
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
        block_size=block_size,
    )


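# The priority behaviour exercised by the tests below comes from
# `policy="priority"` on the SchedulerConfig; the rest of this helper is just a
# minimal, self-contained config so a Scheduler can be built in isolation.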


_none_hash_initialized = False


def create_requests_with_priority(
    num_requests: int,
    priorities: list[int],
    arrival_times: list[float] | None = None,
    num_tokens: int = 10,
    mm_hashes_list: list[list[str]] | None = None,
    mm_positions: list[list[PlaceholderRange]] | None = None,
    max_tokens: int = 16,
    stop_token_ids: list[int] | None = None,
    prompt_logprobs: int | None = None,
    starting_idx: int = 0,
    same_prompt: bool = False,
    block_size: int = 16,
    req_ids: list[str] | None = None,
):
    """Create requests with specified priorities and arrival times."""
    assert len(priorities) == num_requests
    if arrival_times is not None:
        assert len(arrival_times) == num_requests
    else:
        arrival_times = [float(i) for i in range(num_requests)]

    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(sha256)
        _none_hash_initialized = True

    block_hasher = get_request_block_hasher(block_size, sha256)
    sampling_params = SamplingParams(
        ignore_eos=False,
        max_tokens=max_tokens,
        stop_token_ids=stop_token_ids,
        prompt_logprobs=prompt_logprobs,
    )
    requests = []

    if mm_hashes_list is not None:
        # NOTE: allow manual input; some mm items can share the same identifier.
        # The number of mm_hashes and mm_positions for each request should match.
        assert mm_positions is not None, (
            "mm_positions must be provided when mm_hashes_list is provided"
        )
        assert len(mm_hashes_list) == len(mm_positions) == num_requests
        assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]

    # Since the same identifier implies identical encoder output, verify that
    # mm items sharing an identifier also share the same mm_position.length.
    seen_hashes: dict[str, int] = {}

    if req_ids:
        assert len(req_ids) == num_requests
    else:
        req_ids = [f"{i + starting_idx}" for i in range(num_requests)]

    for i in range(num_requests):
        mm_features = []

        for j, position in enumerate(
            mm_positions[i] if mm_positions is not None else []
        ):
            if mm_hashes_list is not None:
                identifier = mm_hashes_list[i][j]

                # Verify the position length is consistent for this identifier.
                position_length = position.length
                if identifier in seen_hashes:
                    assert seen_hashes[identifier] == position_length, (
                        f"mm_hash '{identifier}' has inconsistent position lengths: "
                        f"previously {seen_hashes[identifier]}, now {position_length} "
                        f"at request {i}, position {j}"
                    )
                else:
                    seen_hashes[identifier] = position_length
            else:
                # Unique dummy hash for each mm item
                identifier = f"hash{i}_{j}"
            mm_feature = MultiModalFeatureSpec(
                data=MultiModalKwargsItem.dummy("dummy_m"),
                mm_position=position,
                identifier=identifier,
                modality="image",
            )
            mm_features.append(mm_feature)

        prompt_token_ids = (
            [starting_idx] * num_tokens
            if same_prompt
            else [i + starting_idx] * num_tokens
        )
        request = Request(
            request_id=req_ids[i],
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
            mm_features=mm_features if mm_features else None,
            eos_token_id=EOS_TOKEN_ID,
            arrival_time=arrival_times[i],
            priority=priorities[i],
            block_hasher=block_hasher,
        )
        requests.append(request)
    return requests


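# Unless `req_ids` is given, request ids default to `str(starting_idx + i)`,
# which is why the priority tests below assert on ids like "1", "2", "0".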


def test_priority_scheduling_basic_ordering():
    """Test that requests are scheduled in priority order
    (lower value = higher priority)."""
    scheduler = create_scheduler_with_priority()

    # Create requests with different priorities
    # Priority 0 (highest), 1, 2 (lowest)
    priorities = [2, 0, 1]  # Add in non-priority order
    arrival_times = [1.0, 2.0, 3.0]  # All different arrival times
    requests = create_requests_with_priority(
        num_requests=3, priorities=priorities, arrival_times=arrival_times
    )

    # Add requests in non-priority order
    for request in requests:
        scheduler.add_request(request)

    # Schedule and verify priority order
    output = scheduler.schedule()

    # Should schedule all requests since they fit in budget
    assert len(output.scheduled_new_reqs) == 3

    # Verify they are scheduled in priority order:
    # req_1 (priority 0), req_2 (priority 1), req_0 (priority 2)
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
    assert scheduled_req_ids == ["1", "2", "0"]


def test_priority_scheduling_arrival_time_tiebreaker():
    """Test that arrival time is used
    as tiebreaker when priorities are equal."""
    scheduler = create_scheduler_with_priority()

    # Create requests with same priority but different arrival times
    priorities = [1, 1, 1]  # All same priority
    arrival_times = [3.0, 1.0, 2.0]  # Different arrival times
    requests = create_requests_with_priority(
        num_requests=3, priorities=priorities, arrival_times=arrival_times
    )

    # Add requests in non-arrival order
    for request in requests:
        scheduler.add_request(request)

    # Schedule and verify arrival time order
    output = scheduler.schedule()

    # Should schedule all requests since they fit in budget
    assert len(output.scheduled_new_reqs) == 3

    # Verify they are scheduled in arrival time order:
    # req_1 (1.0), req_2 (2.0), req_0 (3.0)
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
    assert scheduled_req_ids == ["1", "2", "0"]


def test_priority_scheduling_mixed_priority_and_arrival():
    """Test priority scheduling with mixed priorities and arrival times."""
    scheduler = create_scheduler_with_priority()

    # Create requests with mixed priorities and arrival times
    priorities = [2, 1, 1, 0]  # Mixed priorities
    arrival_times = [1.0, 3.0, 2.0, 4.0]  # Mixed arrival times
    requests = create_requests_with_priority(
        num_requests=4, priorities=priorities, arrival_times=arrival_times
    )

    # Add requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule and verify order
    output = scheduler.schedule()

    # Should schedule all requests since they fit in budget
    assert len(output.scheduled_new_reqs) == 4

    # Expected order:
    # 1. req_3 (priority 0, arrival 4.0)
    # 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1
    # 3. req_1 (priority 1, arrival 3.0)
    # 4. req_0 (priority 2, arrival 1.0)
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
    assert scheduled_req_ids == ["3", "2", "1", "0"]


def test_priority_scheduling_preemption():
    """Test that priority scheduling preempts
    lower priority requests when memory is constrained."""
    # Create scheduler with very limited memory to force preemption
    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,  # Allow multiple requests
        max_num_batched_tokens=200,
        num_blocks=6,  # Very limited blocks to force memory pressure
        block_size=16,  # Standard block size
    )

    # Create initial low-priority requests that will consume most memory
    low_priority_requests = create_requests_with_priority(
        num_requests=2,
        priorities=[5, 5],  # Low priority
        arrival_times=[1.0, 2.0],
        num_tokens=30,  # Large enough to consume significant memory
    )

    # Add and schedule low priority requests
    for request in low_priority_requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 2

    # Simulate model execution to move requests to running state
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in low_priority_requests],
        req_id_to_index={
            req.request_id: i for i, req in enumerate(low_priority_requests)
        },
        sampled_token_ids=[[100] for _ in low_priority_requests],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # Verify both requests are running
    assert len(scheduler.running) == 2

    # Now add a high-priority request that requires memory allocation
    # This should trigger preemption due to memory constraints
    high_priority_request = create_requests_with_priority(
        num_requests=1,
        priorities=[0],  # High priority
        arrival_times=[3.0],
        num_tokens=30,  # Large enough to require significant memory
    )[0]

    scheduler.add_request(high_priority_request)

    # Schedule again - this should trigger
    # preemption when trying to allocate memory
    output = scheduler.schedule()

    # Due to the scheduler's design, if preemption happens
    # during running request scheduling,
    # waiting requests won't be scheduled in the same step.
    # Let's check if preemption occurred by looking at the waiting queue.

    # If preemption happened, we should see requests in the
    # waiting queue
    if len(scheduler.waiting) > 1:  # high priority + preempted request
        # Preemption occurred - verify the high priority request
        # gets scheduled next
        output2 = scheduler.schedule()
        assert len(output2.scheduled_new_reqs) == 1
        # High priority request
        assert output2.scheduled_new_reqs[0].req_id == "0"
    else:
        # No preemption needed - all requests fit
        # This is also valid behavior if memory allows
        assert len(output.scheduled_new_reqs) == 1
        # High priority request
        assert output.scheduled_new_reqs[0].req_id == "0"


def test_priority_scheduling_no_preemption_when_space_available():
    """Test that preemption doesn't happen
    when there's space for new requests."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,  # Allow 3 concurrent requests
        max_num_batched_tokens=200,  # Sufficient token budget
    )

    # Add two low-priority running requests
    low_priority_requests = create_requests_with_priority(
        num_requests=2, priorities=[5, 5], arrival_times=[1.0, 2.0], num_tokens=30
    )

    for request in low_priority_requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in low_priority_requests],
        req_id_to_index={
            req.request_id: i for i, req in enumerate(low_priority_requests)
        },
        sampled_token_ids=[[100] for _ in low_priority_requests],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # Add high-priority request
    high_priority_request = create_requests_with_priority(
        num_requests=1, priorities=[0], arrival_times=[3.0], num_tokens=30
    )[0]

    scheduler.add_request(high_priority_request)

    # Schedule - should not preempt since there's space
    output = scheduler.schedule()

    # Should schedule the new request without preemption
    assert len(output.scheduled_new_reqs) == 1
    assert len(scheduler.running) == 3  # All three requests running
    assert len(scheduler.waiting) == 0  # No requests waiting


def test_priority_scheduling_preemption_victim_selection():
    """Test that the correct victim is selected for
    preemption based on priority and arrival time."""
    # This test verifies the priority-based victim selection logic
    # by checking the waiting queue order after adding requests with different
    # priorities
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Force sequential processing to test priority order
    )

    # Create requests with different priorities
    requests = create_requests_with_priority(
        num_requests=3,
        priorities=[3, 2, 0],  # Different priorities: low, medium, high
        arrival_times=[1.0, 2.0, 3.0],
        num_tokens=10,
    )

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule - should only schedule the highest priority request
    # (req_2, priority 0)
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_new_reqs[0].req_id == "2"  # Highest priority

    # Verify the waiting queue has the remaining requests in priority order
    assert len(scheduler.waiting) == 2

    # Extract waiting requests and verify priority order
    waiting_requests = list(scheduler.waiting)

    waiting_priorities = [req.priority for req in waiting_requests]
    waiting_req_ids = [req.request_id for req in waiting_requests]

    # Should be req_1 (priority 2) then req_0 (priority 3)
    assert waiting_priorities == [2, 3]
    assert waiting_req_ids == ["1", "0"]


def test_priority_scheduling_equal_priority_preemption():
    """Test arrival time tiebreaker when requests have equal priority."""
    # This test verifies that arrival time is used as a tiebreaker for equal
    # priorities
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Force sequential processing
    )

    # Create requests with same priority but different arrival times
    requests = create_requests_with_priority(
        num_requests=3,
        priorities=[2, 2, 2],  # Same priority
        arrival_times=[3.0, 1.0, 2.0],  # Different arrival times
        num_tokens=10,
    )

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule - should schedule the request with earliest arrival time
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_new_reqs[0].req_id == "1"  # Earliest arrival (1.0)

    # Verify the waiting queue has remaining requests in arrival time order
    assert len(scheduler.waiting) == 2

    # Extract waiting requests and verify arrival time order
    waiting_requests = list(scheduler.waiting)

    waiting_arrival_times = [req.arrival_time for req in waiting_requests]
    waiting_req_ids = [req.request_id for req in waiting_requests]

    # Should be req_2 (arrival 2.0) then req_0 (arrival 3.0)
    assert waiting_arrival_times == [2.0, 3.0]
    assert waiting_req_ids == ["2", "0"]


def test_priority_scheduling_waiting_queue_order():
    """Test that the waiting queue maintains priority order."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Only one request can run at a time
    )

    # Create multiple requests with different priorities
    requests = create_requests_with_priority(
        num_requests=4,
        priorities=[3, 1, 2, 0],  # Mixed priorities
        arrival_times=[1.0, 2.0, 3.0, 4.0],
        num_tokens=10,
    )

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule - should only schedule the highest priority request
    # (req_3, priority 0)
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_new_reqs[0].req_id == "3"

    # Verify waiting queue has remaining requests in priority order
    assert len(scheduler.waiting) == 3

    # Extract requests from waiting queue
    # (it's a heap, so we need to pop to see order)
    waiting_requests = list(scheduler.waiting)

    waiting_priorities = [req.priority for req in waiting_requests]
    waiting_req_ids = [req.request_id for req in waiting_requests]

    # Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3)
    assert waiting_req_ids == ["1", "2", "0"]
    assert waiting_priorities == [1, 2, 3]


def test_priority_scheduling_fcfs_fallback():
    """Test that FCFS behavior is maintained when all
    requests have same priority."""
    scheduler = create_scheduler_with_priority()

    # Create requests with same priority but different arrival times
    priorities = [1, 1, 1, 1]  # All same priority
    arrival_times = [4.0, 1.0, 3.0, 2.0]  # Different arrival times
    requests = create_requests_with_priority(
        num_requests=4, priorities=priorities, arrival_times=arrival_times
    )

    # Add requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule
    output = scheduler.schedule()

    # Should schedule all requests in arrival time order
    assert len(output.scheduled_new_reqs) == 4
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]

    # Expected order by arrival time:
    # req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0)
    assert scheduled_req_ids == ["1", "3", "2", "0"]


def test_priority_scheduling_with_limited_slots():
    """Test priority scheduling when max_num_seqs limits concurrent requests."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=2,  # Only allow 2 concurrent requests
        max_num_batched_tokens=1000,  # Plenty of token budget
    )

    # Create requests with different priorities
    requests = create_requests_with_priority(
        num_requests=4,
        priorities=[3, 1, 2, 0],  # Mixed priorities
        arrival_times=[1.0, 2.0, 3.0, 4.0],
        num_tokens=10,
    )

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule - should only schedule the 2 highest priority requests
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 2

    # Should schedule req_3 (priority 0) and req_1 (priority 1)
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
    assert "3" in scheduled_req_ids  # Priority 0
    assert "1" in scheduled_req_ids  # Priority 1

    # Remaining requests should be in waiting queue in priority order
    assert len(scheduler.waiting) == 2

    # Extract waiting requests and verify order
    waiting_requests = list(scheduler.waiting)
    waiting_priorities = [req.priority for req in waiting_requests]
    waiting_req_ids = [req.request_id for req in waiting_requests]

    # Should be req_2 (priority 2) then req_0 (priority 3)
    assert waiting_priorities == [2, 3]
    assert waiting_req_ids == ["2", "0"]


def test_priority_scheduling_heap_property():
    """Test that the waiting queue maintains heap
    property for priority scheduling."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Only one request can run at a time
    )

    # Add requests in random priority order
    priorities = [5, 1, 8, 3, 2, 7, 4, 6]
    arrival_times = [float(i) for i in range(len(priorities))]
    requests = create_requests_with_priority(
        num_requests=len(priorities),
        priorities=priorities,
        arrival_times=arrival_times,
        num_tokens=10,
    )

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule one request at a time and verify priority order
    scheduled_priorities = []

    while scheduler.waiting:
        output = scheduler.schedule()
        if output.scheduled_new_reqs:
            req = output.scheduled_new_reqs[0]
            scheduled_priorities.append(requests[int(req.req_id)].priority)

            # Simulate completion to make room for next request
            model_output = ModelRunnerOutput(
                req_ids=[req.req_id],
                req_id_to_index={req.req_id: 0},
                sampled_token_ids=[[100]],
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[],
            )
            scheduler.update_from_output(output, model_output)

            # Finish the request to make room for the next one
            scheduler.finish_requests(req.req_id, RequestStatus.FINISHED_STOPPED)

    # Verify requests were scheduled in priority order (lowest value first)
    expected_priorities = sorted(priorities)
    assert scheduled_priorities == expected_priorities


def test_schedule_skip_tokenizer_init():
    scheduler = create_scheduler(skip_tokenizer_init=True)
    requests = create_requests(num_requests=5)
    for request in requests:
        scheduler.add_request(request)
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)


def test_schedule_skip_tokenizer_init_structured_output_request():
    scheduler = create_scheduler(skip_tokenizer_init=True)
    structured_outputs_params = StructuredOutputsParams(regex="[0-9]+")
    sampling_params = SamplingParams(
        ignore_eos=False,
        max_tokens=16,
        structured_outputs=structured_outputs_params,
    )
    request = Request(
        request_id="0",
        prompt_token_ids=[0, 1],
        mm_features=None,
        sampling_params=sampling_params,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
    )
    scheduler.add_request(request)
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1


@pytest.mark.parametrize(
    "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
    use_ec_connector, ec_role
):
    """Test that priority scheduling preempts lower priority requests
    when out of KV cache space."""
    # Create scheduler with very limited memory to force preemption
    scheduler = create_scheduler_with_priority(
        max_num_seqs=2,  # Allow multiple requests
        max_num_batched_tokens=200,
        num_blocks=5,  # Can hold 64 tokens (first block is null)
        block_size=16,  # Standard block size
        use_kv_connector=True,
        # encoder connector should not affect test results
        use_ec_connector=use_ec_connector,
        ec_role=ec_role,
    )

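    # A rough restatement of the sizing above (assuming the first block is the
    # reserved null block): 4 usable blocks of 16 tokens each, i.e. 64 tokens of
    # KV space shared by the two requests created below.
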
    # Create a request and schedule it
    request_low = create_requests_with_priority(
        num_requests=1,
        priorities=[1],
        arrival_times=[0.0],
        num_tokens=30,
        starting_idx=0,
    )[0]
    scheduler.add_request(request_low)
    # 1st schedule
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 1

    # Simulate model execution - 1st decode
    model_output = ModelRunnerOutput(
        req_ids=[request_low.request_id],
        req_id_to_index={request_low.request_id: 0},
        sampled_token_ids=[[100]],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # Create a high priority request and schedule it
    request_high = create_requests_with_priority(
        num_requests=1,
        priorities=[0],
        arrival_times=[1.0],
        num_tokens=32,
        starting_idx=1,
    )[0]
    scheduler.add_request(request_high)
    # 2nd schedule
    output = scheduler.schedule()
    # KV cache should be full at this point
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_cached_reqs.num_reqs == 1
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 2

    # Simulate model execution - 2nd decode
    requests = [request_low, request_high]
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
        sampled_token_ids=[[100] for _ in requests],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # 3rd schedule - this should trigger preemption:
    # req_low needs 32 tokens = 2 blocks,
    # req_high needs 33 tokens = 3 blocks,
    # so they don't fit in 4 blocks.
    output = scheduler.schedule()

    # Should have preempted req_low
    assert len(output.scheduled_new_reqs) == 0
    assert output.scheduled_cached_reqs.num_reqs == 1
    assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
    assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 1

    # Simulate model execution - 3rd decode
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
        sampled_token_ids=[[], [100]],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    # Finish the requests to make room for the preempted requests to resume
    scheduler.update_from_output(output, model_output)
    scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)

    # 4th schedule - this should trigger the resumption
    output = scheduler.schedule()
    scheduled_cached_reqs = output.scheduled_cached_reqs
    resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption

    assert len(output.scheduled_new_reqs) == 0
    assert scheduled_cached_reqs.num_reqs == 1
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 1

    # Preempted request resumed in scheduled_cached_reqs
    assert len(resumed_from_preemption) == 1
    assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
    assert resumed_from_preemption[0]
    assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
    assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
    # Resumed tokens include 30 prompt tokens and 2 decoded tokens
    assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
    assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100


@pytest.mark.parametrize(
    ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
    [
        (True, False, True),
        (False, False, False),
        # Encoder-decoder models should always have it disabled
        (False, True, False),
        (True, True, False),
    ],
)
def test_chunked_prefill_disabled_for_encoder_decoder(
    enable_chunked_prefill: bool, is_encoder_decoder: bool, expect_enabled: bool
) -> None:
    """Validate that chunked prefill is appropriately disabled for
    encoder-decoder models."""
    scheduler_config = SchedulerConfig(
        enable_chunked_prefill=enable_chunked_prefill,
        is_encoder_decoder=is_encoder_decoder,
        # Must be <= max_num_batched_tokens if chunked prefill is disabled
        max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
    )

    # `is_encoder_decoder` should only be used during construction
    # of the config, and otherwise stored in the model config.
    assert "is_encoder_decoder" not in vars(scheduler_config)
    assert "is_encoder_decoder" not in [
        f.name for f in dataclasses.fields(scheduler_config)
    ]
    _validate_chunked_prefill_settings_for_encoder_decoder(
        scheduler_config, is_encoder_decoder, expect_enabled
    )

    # Ensure it is retained in VllmConfig, even after its post-init.
    vllm_config = VllmConfig(scheduler_config=scheduler_config)
    _validate_chunked_prefill_settings_for_encoder_decoder(
        vllm_config.scheduler_config, is_encoder_decoder, expect_enabled
    )


def _validate_chunked_prefill_settings_for_encoder_decoder(
    scheduler_config: SchedulerConfig, is_encoder_decoder: bool, expect_enabled: bool
) -> None:
    """Validate chunked prefill settings in the scheduler config for
    encoder-decoder models."""
    assert scheduler_config.enable_chunked_prefill is expect_enabled
    if is_encoder_decoder:
        # Encoder-decoder models should automatically disable chunked multimodal
        # inputs as well
        assert scheduler_config.disable_chunked_mm_input is not expect_enabled
    if is_encoder_decoder and not expect_enabled:
        assert scheduler_config.long_prefill_token_threshold == 0


# ==============================================================================
# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests start
# NOTE: In the E->P->D disagg case, both the KV and EC Connectors work in the
# P instance. Unless specified otherwise, the presence of a KV Connector should
# not affect any test results.
# ==============================================================================

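# A sketch of the flow these tests assume (based on the mocks used below): the
# scheduler under test plays the P (prefill) role with an ec_consumer-role
# connector. For each multimodal item the scheduler asks the EC connector
# `has_caches(...)`; on a hit the encoder output is treated as available
# externally instead of being scheduled as encoder work on this instance.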
|
|
|
|
|
def _assert_right_encoder_cache_allocated(
|
|
|
|
|
scheduler: Scheduler,
|
|
|
|
|
hashes_to_check: list[str] | None = None,
|
|
|
|
|
requests: list[Request] | None = None,
|
|
|
|
|
expected_total_allocated: int | None = None,
|
|
|
|
|
):
|
|
|
|
|
"""Check whether encoder cache is allocated correctly."""
|
|
|
|
|
encoder_cache_manager = scheduler.encoder_cache_manager
|
|
|
|
|
|
|
|
|
|
# Verify encoder cache manager exists
|
|
|
|
|
assert encoder_cache_manager is not None, "Encoder cache manager should exist"
|
|
|
|
|
|
|
|
|
|
# Verify number of cache
|
|
|
|
|
if expected_total_allocated is not None:
|
|
|
|
|
assert len(encoder_cache_manager.cached) == expected_total_allocated
|
|
|
|
|
if expected_total_allocated == 0:
|
|
|
|
|
return
|
|
|
|
|
|
    # Verify each request with MM data is in cache
    cached_hashes = set(encoder_cache_manager.cached.keys())

    if hashes_to_check:
        missed_hashes = set(hashes_to_check) - cached_hashes
        assert not missed_hashes, (
            f"Missed hashes: {missed_hashes} "
            f"Existing encoder cache: {encoder_cache_manager.cached}"
        )

    for req in requests if requests is not None else []:
        if req.mm_features:
            mm_hashes = [f.identifier for f in req.mm_features]
            req_hashes = set(mm_hashes)  # unique hashes set
            missed_hashes = req_hashes - cached_hashes
            assert not missed_hashes, (
                f"Missed hashes in cache for request {req.request_id}: "
                f"{missed_hashes} "
                f"Existing encoder cache: {encoder_cache_manager.cached}"
            )


def _assert_right_ec_connector_metadata(
    output: SchedulerOutput,
    mm_features_list: list[MultiModalFeatureSpec],
):
    """Verify that ECConnector metadata EXACTLY matches the input MM data"""
    # Get the connector metadata
    metadata = output.ec_connector_metadata

    # Create lookup dictionaries for efficient access
    metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas}

    # Check that all required identifiers exist in the metadata, and no extras
    # (in ECSharedStorageConnector format).
    # NOTE: even with the same identifier, the mm_features can be different
    # since their mm_position can be at different offsets, etc.
    identifiers_dict = {f.identifier for f in mm_features_list}
    assert set(metadata_dict.keys()) == identifiers_dict

    # Verify the info matches
    for mm_feature in mm_features_list:
        identifier = mm_feature.identifier
        assert metadata_dict[identifier].mm_hash == identifier
        assert metadata_dict[identifier].num_token == mm_feature.mm_position.length


def _assert_right_encoder_inputs(
    output: SchedulerOutput,
    check_exist: bool | None = True,
    requests: list[Request] | None = None,
    expected_encoder_inputs: list[list[int]] | None = None,
    expected_total_reqs: int | None = None,
):
    """Verify that requests/mm_hashes are (or are not) in the scheduled encoder inputs.

    If check_exist is False, this function asserts that the requests
    are NOT in the encoder inputs."""

    # Get the scheduled encoder inputs
    # NOTE: scheduled_encoder_inputs is a dictionary with request id as key
    scheduled_encoder_inputs = output.scheduled_encoder_inputs

    # Check if scheduled_encoder_inputs is empty as expected
    if expected_total_reqs is not None:
        assert len(scheduled_encoder_inputs) == expected_total_reqs
        if expected_total_reqs == 0:
            return

    # Number of expected encoder inputs should match number of requests
    if expected_encoder_inputs:
        assert check_exist and requests is not None  # only supports inputs expected to exist
        assert len(requests) == len(expected_encoder_inputs)

    # Check requests (do not) exist as expected
    for i, request in enumerate(requests if requests is not None else []):
        assert (request.request_id in scheduled_encoder_inputs) is check_exist, (
            f"Request {request.request_id} presence mismatch: expected {check_exist}, "
            f"got {request.request_id in scheduled_encoder_inputs}"
        )
        if expected_encoder_inputs:
            scheduled_encoder_input = scheduled_encoder_inputs[request.request_id]
            assert scheduled_encoder_input == expected_encoder_inputs[i]


def test_scheduler_no_ec_connector_by_default():
    """Test scheduler doesn't have an EC connector by default."""
    scheduler = create_scheduler()
    assert scheduler.ec_connector is None


@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_text_only_request(use_kv_connector):
    """Test text-only requests don't allocate encoder cache."""
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        use_kv_connector=use_kv_connector,
        use_ec_connector=True,
        ec_role="ec_consumer",
    )

    NUM_PROMPT_TOKENS = 100

    # Create text-only request (no mm_positions)
    requests = create_requests(
        num_requests=1,
        num_tokens=NUM_PROMPT_TOKENS,
    )
    assert not requests[0].mm_features  # No MM data

    scheduler.add_request(requests[0])
    output = scheduler.schedule()

    # Should schedule
    assert len(output.scheduled_new_reqs) == 1

    # Scheduled tokens should equal prompt tokens exactly
    scheduled = output.num_scheduled_tokens[requests[0].request_id]
    assert scheduled == NUM_PROMPT_TOKENS, (
        f"Text-only should schedule {NUM_PROMPT_TOKENS}, got {scheduled}"
    )

    # Encoder cache should be empty
    _assert_right_encoder_cache_allocated(scheduler, expected_total_allocated=0)

    # ECConnector should carry no metadata
    _assert_right_ec_connector_metadata(output, mm_features_list=[])

    # Scheduled encoder input should be empty; no mm to compute
    _assert_right_encoder_inputs(output, expected_total_reqs=0)


@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_cache_hit_external_load(use_kv_connector):
    """Test ec_consumer loads from external cache when hit.

    A normal basic operation for EPD disaggregation."""
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        enable_prefix_caching=True,
        # kv connector should not affect test results
        use_kv_connector=use_kv_connector,
        use_ec_connector=True,
        ec_role="ec_consumer",
    )

    # Create MM request
    NUM_TOKENS = 200  # NOTE: includes mm tokens
    NUM_ENCODER_TOKENS = 100
    mm_hashes_list = [["hash_test1"]]
    mm_positions = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS)]]
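    # NOTE: PlaceholderRange(offset, length) marks where the mm item's tokens
    # sit inside the prompt; here the single 100-token image placeholder
    # occupies prompt positions 0-99 of the 200-token request.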

    request = create_requests(
        num_requests=1,
        num_tokens=NUM_TOKENS,
        mm_hashes_list=mm_hashes_list,
        mm_positions=mm_positions,
    )[0]

    # Mock cache hit - encoder cache exists externally
    scheduler.ec_connector.has_caches = Mock(return_value=[True])
    scheduler.ec_connector.update_state_after_alloc = Mock(
        wraps=scheduler.ec_connector.update_state_after_alloc
    )
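    # NOTE: Mock(wraps=obj) delegates every call to the real connector method
    # while still recording the call arguments, so the assert_called_with()
    # checks below can observe how the scheduler drove the connector without
    # changing its behavior.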

    scheduler.add_request(request)
    output = scheduler.schedule()

    # Should schedule prompt tokens
    scheduled_tokens = output.num_scheduled_tokens[request.request_id]
    assert scheduled_tokens == NUM_TOKENS

    # Should have called update_state_after_alloc for the external load
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request, 0)

    # Encoder cache should contain mm items from the request
    _assert_right_encoder_cache_allocated(scheduler, requests=[request])

    # ECConnector should carry metadata of the request
    _assert_right_ec_connector_metadata(output, mm_features_list=request.mm_features)

    # Scheduled encoder input should be empty; no mm to compute
    _assert_right_encoder_inputs(output, expected_total_reqs=0)


@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_cache_miss_computes_locally(use_kv_connector):
    """Test consumer can compute encoder locally on a cache miss (fallback)."""
    # The consumer computes the encoder cache itself if it doesn't receive it
    # from external storage.
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        enable_prefix_caching=True,
        use_kv_connector=use_kv_connector,
        use_ec_connector=True,
        ec_role="ec_consumer",
    )

    # Verify consumer role
    assert scheduler.ec_connector is not None
    assert not scheduler.ec_connector.is_producer

    # Create MM request
    request_mm_missed = create_requests(
        num_requests=1,
        num_tokens=200,  # Total (including 100 MM)
        mm_positions=[[PlaceholderRange(offset=0, length=100)]],  # 100 MM tokens
    )[0]

    # Mock cache miss - encoder cache doesn't exist externally
    scheduler.ec_connector.has_caches = Mock(return_value=[False])

    scheduler.add_request(request_mm_missed)
    output = scheduler.schedule()

    # SCHEDULER should decide to compute the encoder locally (fallback)
    assert len(output.scheduled_new_reqs) == 1

    # Should schedule full prompt tokens
    scheduled_tokens = output.num_scheduled_tokens[request_mm_missed.request_id]
    assert scheduled_tokens == 200, (
        f"Expected 200 tokens on cache miss, got {scheduled_tokens}"
    )

    # Encoder cache should contain mm items from the request
    _assert_right_encoder_cache_allocated(scheduler, requests=[request_mm_missed])

    # ECConnector should carry no metadata (missed cache)
    _assert_right_ec_connector_metadata(output, mm_features_list=[])

    # Scheduled encoder input contains the mm item for request_mm_missed
    _assert_right_encoder_inputs(
        output,
        requests=[request_mm_missed],
        expected_encoder_inputs=[[0]],  # index 0 of the mm item
        expected_total_reqs=1,
    )

    # Then MODEL_RUNNER will execute the encoder and cache the result


@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
    """Test consumer with partial cache hit (local & connector) with 2 requests."""
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        enable_prefix_caching=True,
        use_kv_connector=use_kv_connector,
        use_ec_connector=True,
        ec_role="ec_consumer",
    )

    # Create MM request
    NUM_TOKENS_1 = 300  # NOTE: includes mm tokens
    NUM_ENCODER_TOKENS_1 = 50
    mm_hashes_list_1 = [["hash1_A", "hash1_B", "hash1_A", "hash1_F"]]
    mm_positions_1 = [
        [
            PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1),
            PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_1),
            PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_1),
            PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1),
        ]
    ]
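    # Layout (for reference): four 50-token placeholders at offsets 0, 100,
    # 200 and 250 of the 300-token prompt, with "hash1_A" used twice, so the
    # request carries 4 mm features but only 3 unique hashes.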

    # Create request with 4 MM items, with 2 identical items
    request1 = create_requests(
        num_requests=1,
        num_tokens=NUM_TOKENS_1,
        mm_hashes_list=mm_hashes_list_1,
        mm_positions=mm_positions_1,
        max_tokens=1,  # For simplicity
    )[0]

    # Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist
    scheduler.ec_connector.has_caches = Mock(return_value=[False, True, False, True])
    scheduler.ec_connector.update_state_after_alloc = Mock(
        wraps=scheduler.ec_connector.update_state_after_alloc
    )

    scheduler.add_request(request1)
    output = scheduler.schedule()

    # Should schedule all tokens
    scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
    assert scheduled_tokens == NUM_TOKENS_1

    # Encoder cache should contain all mm items from the request
    _assert_right_encoder_cache_allocated(scheduler, requests=[request1])

    # Should have called update_state_after_alloc for the external load
    scheduler.ec_connector.update_state_after_alloc.assert_called()
    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # ECConnector should carry metadata for the 2nd and 4th mm items
    _assert_right_ec_connector_metadata(
        output, mm_features_list=[request1.mm_features[1], request1.mm_features[3]]
    )

    # Should schedule ONLY 1 encoder input (index 0), no repeat for identical items
    _assert_right_encoder_inputs(
        output,
        requests=[request1],
        expected_encoder_inputs=[[0]],  # index 0 of the mm item ONLY
        expected_total_reqs=1,
    )

    # Simulate model execution for 1 step
    model_output = ModelRunnerOutput(
        req_ids=[request1.request_id],
        req_id_to_index={request1.request_id: 0},
        sampled_token_ids=[[100]],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # request1 is finished after outputting 1 token
    # Finish request
    scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)

    # Create another request with 4 MM items
    NUM_TOKENS_2 = 400
    NUM_ENCODER_TOKENS_2 = 50
    mm_hashes_list_2 = [["hash1_C", "hash1_D", "hash1_E", "hash1_A"]]
    mm_positions_2 = [
        [
            PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
            PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_2),
            PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_2),
            PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2),
        ]
    ]

    request2 = create_requests(
        num_requests=1,
        num_tokens=NUM_TOKENS_2,
        mm_hashes_list=mm_hashes_list_2,
        mm_positions=mm_positions_2,
        max_tokens=1,  # For simplicity
    )[0]

    # Mock partial cache hit: only hash1_A and hash1_C exist in the connector
    scheduler.ec_connector.has_caches = Mock(return_value=[True, False, False, True])

    scheduler.add_request(request2)
    output = scheduler.schedule()

    # Check
    # Should schedule all tokens
    scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
    assert scheduled_tokens == 400

    # Encoder cache should contain all mm items from request2
    _assert_right_encoder_cache_allocated(scheduler, requests=[request2])

    # Should call update_state_after_alloc for hash1_C ONLY.
    # hash1_A should not be loaded from the connector
    # since it was computed for the last request & exists in the local cache.
    # Order of getting encoder cache should be: local cache -> connector -> compute
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 0)
    scheduler.ec_connector.update_state_after_alloc.assert_called_once()

    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # ECConnector should carry metadata for hash1_C only (index 0)
    _assert_right_ec_connector_metadata(
        output, mm_features_list=[request2.mm_features[0]]
    )

    # Should schedule 2 encoder inputs: hash1_D and hash1_E (indices 1, 2)
    _assert_right_encoder_inputs(
        output,
        requests=[request2],
        expected_encoder_inputs=[[1, 2]],
        expected_total_reqs=1,
    )


@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"])
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector):
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        max_num_seqs=10,  # allow multiple requests
        max_num_batched_tokens=2048,
        enable_prefix_caching=True,
        use_kv_connector=use_kv_connector,
        use_ec_connector=True,
        ec_role="ec_consumer",
    )
    mm_hashes_list = [[f"hash_{i}"] for i in range(10)]
    mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
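    # Each of the 10 requests carries a single 100-token mm item with a unique
    # hash, so every has_caches() call below sees a one-element list and the
    # [True] / [False] mocks cover all items of a request.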
    requests = create_requests(
        num_requests=10,
        num_tokens=200,
        mm_hashes_list=mm_hashes_list,
        mm_positions=mm_positions,
    )
    for request in requests:
        scheduler.add_request(request)

    # Set up to test different encoder cache existence scenarios.
    # Order of getting encoder cache should be: local cache -> connector -> compute
    scheduler.ec_connector.update_state_after_alloc = Mock(
        wraps=scheduler.ec_connector.update_state_after_alloc
    )

    if cache_exist == "local":
        # Allocate cache in the cache manager manually to mimic a local cache hit
        for req in requests:
            scheduler.encoder_cache_manager.allocate(req, 0)
    else:
        # Make sure the local encoder cache is empty
        scheduler.encoder_cache_manager.cached = {}

    if cache_exist == "connector_only":
        # Cache exists in ec_connector
        scheduler.ec_connector.has_caches = Mock(return_value=[True])
    elif cache_exist == "no_where":
        scheduler.ec_connector.has_caches = Mock(return_value=[False])

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert output.scheduled_cached_reqs.num_reqs == 0
    assert len(output.finished_req_ids) == 0
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
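    # NOTE: create_requests (from .utils) appears to assign sequential numeric
    # request ids, which is why int(req_id) can index back into `requests`
    # here.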

    ## Encoder-cache-specific checks:
    # mm_hashes of requests exist in cache after scheduling for all scenarios
    _assert_right_encoder_cache_allocated(scheduler, requests=requests)

    # Should only call update_state_after_alloc when loaded externally
    if cache_exist == "connector_only":
        scheduler.ec_connector.update_state_after_alloc.assert_called_with(
            requests[-1], 0
        )

        # Concat mm_features for the 10 requests together
        mm_features_list = [feature for req in requests for feature in req.mm_features]

        # Check metadata should contain mm data for all 10 requests
        _assert_right_ec_connector_metadata(output, mm_features_list=mm_features_list)
    else:
        scheduler.ec_connector.update_state_after_alloc.assert_not_called()
        # ECConnector should carry no metadata
        _assert_right_ec_connector_metadata(output, mm_features_list=[])

    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # Should only schedule encoder inputs when the cache is not found anywhere
    if cache_exist == "no_where":
        _assert_right_encoder_inputs(
            output,
            requests=requests,
            expected_encoder_inputs=[[0] for _ in range(10)],
            expected_total_reqs=10,
        )
    else:
        _assert_right_encoder_inputs(output, expected_total_reqs=0)


@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_unable_to_allocate(use_kv_connector):
    """
    Test whether the scheduler with ECConnector can handle being
    unable to allocate (running out of blocks).
    """

    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 4
    NUM_BLOCKS = 10
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        enable_prefix_caching=True,
        use_kv_connector=use_kv_connector,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        use_ec_connector=True,
        ec_role="ec_consumer",
    )

    # Mock ec_connector external cache load behavior
    scheduler.ec_connector.has_caches = Mock(return_value=[True])
    scheduler.ec_connector.update_state_after_alloc = Mock(
        wraps=scheduler.ec_connector.update_state_after_alloc
    )

    # Create two requests. The second request will not be able to
    # allocate slots because it will not have enough blocks.
    NUM_REQUESTS = 2
    NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
    MAX_TOKENS = 2
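    # Sizing sketch (for reference): NUM_TOKENS = (10 // 2 + 1) * 4 = 24 tokens,
    # i.e. 6 blocks per request at BLOCK_SIZE=4. With only NUM_BLOCKS=10 blocks
    # in the pool, two such requests cannot both hold their blocks at once,
    # which is what forces the second request to wait.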

    requests = create_requests(
        num_requests=NUM_REQUESTS,
        num_tokens=NUM_TOKENS,
        mm_hashes_list=[["hash_1"], ["hash_2"]],
        mm_positions=[
            [PlaceholderRange(offset=1, length=10)] for _ in range(NUM_REQUESTS)
        ],
        max_tokens=MAX_TOKENS,
        block_size=BLOCK_SIZE,
    )
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    # Set up MODEL_RUNNER_OUTPUT to be used by _step_until_done later
    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    # Just one request should be running.
    output = scheduler.schedule()
    scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id]
    assert scheduled_tokens == NUM_TOKENS
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # Should have called update_state_after_alloc for the external load
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(
        scheduler.running[0], 0
    )
    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # All memory should be freed, with one request waiting.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # Just one request should be running.
    output = scheduler.schedule()
    scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id]
    assert scheduled_tokens == NUM_TOKENS
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # update_state_after_alloc should be called for loading the external cache
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(
        scheduler.running[0], 0
    )
    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # All memory should be freed, with no requests waiting / running.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0


@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"])
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_priority_scheduling_ec_connector_preemption_and_resumption(
    cache_exist, use_kv_connector
):
    """Test that priority scheduling preempts lower priority requests
    when out of KV cache space."""
    # Create scheduler with very limited memory to force preemption
    scheduler = create_scheduler_with_priority(
        model="llava-hf/llava-1.5-7b-hf",
        enable_prefix_caching=True,
        max_num_seqs=2,  # allow multiple requests
        # kv connector should not affect test results
        use_kv_connector=use_kv_connector,
        num_blocks=15,  # can hold 224 tokens with 14 blocks (first block is null)
        block_size=16,  # standard block size
        use_ec_connector=True,
        ec_role="ec_consumer",
    )

    # Mock cache hit: both caches exist in the connector (at E->PD initially)
    scheduler.ec_connector.has_caches = Mock(return_value=[True])
    scheduler.ec_connector.update_state_after_alloc = Mock(
        wraps=scheduler.ec_connector.update_state_after_alloc
    )

    # Create a request and schedule it (to be preempted later)
    request_low = create_requests_with_priority(
        num_requests=1,
        priorities=[1],
        arrival_times=[0.0],
        num_tokens=94,
        mm_hashes_list=[["hash_low"]],
        # NOTE: this test only preempts the last block.
        # Placing mm_position in the last block forces the encoding to be recomputed.
        mm_positions=[[PlaceholderRange(offset=82, length=10)]],
        starting_idx=0,
    )[0]
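    # Placement sketch (for reference): with block_size=16, the 94-token prompt
    # spans 6 blocks and its last block covers token positions 80-93; the
    # placeholder at offset 82 with length 10 (positions 82-91) therefore lies
    # entirely inside that last block.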
    scheduler.add_request(request_low)
    # 1st schedule
    output = scheduler.schedule()

    assert len(output.scheduled_new_reqs) == 1
    scheduled_tokens = output.num_scheduled_tokens[request_low.request_id]
    assert scheduled_tokens == 94
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 1

    ## Encoder-cache-specific checks:
    # Encoder cache should contain mm items from the request
    _assert_right_encoder_cache_allocated(scheduler, requests=[request_low])

    # Verify update_state_after_alloc called (external load)
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_low, 0)
    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # ECConnector should carry metadata of the request
    _assert_right_ec_connector_metadata(
        output, mm_features_list=request_low.mm_features
    )

    # Scheduled encoder input should be empty; no mm to compute
    _assert_right_encoder_inputs(output, expected_total_reqs=0)

    # Simulate model execution - 1st decode
    model_output = ModelRunnerOutput(
        req_ids=[request_low.request_id],
        req_id_to_index={request_low.request_id: 0},
        sampled_token_ids=[[100]],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # Create a high priority request and schedule it
    request_high = create_requests_with_priority(
        num_requests=1,
        priorities=[0],
        arrival_times=[1.0],
        num_tokens=128,
        mm_hashes_list=[["hash_high"]],
        mm_positions=[[PlaceholderRange(offset=1, length=10)]],
        max_tokens=2,
        starting_idx=1,
    )[0]
    scheduler.add_request(request_high)
    # 2nd schedule
    output = scheduler.schedule()

    # KV cache should be full at this point
    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_cached_reqs.num_reqs == 1
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 2

    ## Encoder-cache-specific checks:
    # Encoder cache should contain mm items from the request
    _assert_right_encoder_cache_allocated(scheduler, requests=[request_high])

    # Verify update_state_after_alloc called (external load)
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_high, 0)
    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # ECConnector should carry metadata of the request
    _assert_right_ec_connector_metadata(
        output, mm_features_list=request_high.mm_features
    )

    # Scheduled encoder input should be empty; no mm to compute
    _assert_right_encoder_inputs(output, expected_total_reqs=0)

    # Simulate model execution - 2nd decode
    requests = [request_low, request_high]
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
        sampled_token_ids=[[100] for _ in requests],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # 3rd schedule - this should trigger preemption:
    # req_low needs 96 tokens = 6 blocks,
    # req_high needs 129 tokens = 9 blocks,
    # so they don't fit in 14 blocks.
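    # (Worked out: ceil(96 / 16) = 6 and ceil(129 / 16) = 9, and 6 + 9 = 15,
    # one more block than the 14 usable blocks configured above.)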
    output = scheduler.schedule()

    # Should have preempted req_low
    assert len(output.scheduled_new_reqs) == 0
    assert output.scheduled_cached_reqs.num_reqs == 1
    assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
    assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 1

    ## Encoder-cache-specific checks:
    # request_high is in decode phase now
    # ECConnector should carry no metadata
    _assert_right_ec_connector_metadata(output, mm_features_list=[])

    # Scheduled encoder input should be empty; no mm to compute
    _assert_right_encoder_inputs(output, expected_total_reqs=0)

    # Simulate model execution - 3rd decode, after req_low was preempted
    requests = [request_low, request_high]
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
        sampled_token_ids=[[100], [100, 200]],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    # Finish the request to make room for the preempted request to resume
    # req_high is finished after outputting 2 tokens
    scheduler.update_from_output(output, model_output)
    scheduler.finish_requests(
        request_high.request_id, RequestStatus.FINISHED_LENGTH_CAPPED
    )

    # Set up to test different encoder cache existence scenarios after preemption.
    # Order of getting encoder cache should be: local cache -> connector -> compute
    # By default, the cache should still exist locally in this test case
    if cache_exist != "local":
        # Make the local encoder cache empty
        scheduler.encoder_cache_manager.cached = {}

    if cache_exist == "connector_only":
        # Cache exists in ec_connector
        scheduler.ec_connector.has_caches = Mock(return_value=[True])
    elif cache_exist == "no_where":
        scheduler.ec_connector.has_caches = Mock(return_value=[False])

    # 4th schedule - this should trigger req_low resumption from waiting
    output = scheduler.schedule()
    scheduled_cached_reqs = output.scheduled_cached_reqs
    resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption

    assert len(output.scheduled_new_reqs) == 0
    assert scheduled_cached_reqs.num_reqs == 1
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 1

    # Preempted request resumed in scheduled_cached_reqs
    assert len(resumed_from_preemption) == 1
    assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
    assert resumed_from_preemption[0]
    assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
    assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
    ## Resumed tokens include 94 prompt tokens and 2 decoded tokens
    assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 96
    assert scheduled_cached_reqs.resumed_req_token_ids[0][95] == 100
    assert scheduler.running[0].request_id == request_low.request_id
    assert request_high.request_id in output.finished_req_ids

    ## Encoder-cache-specific checks:
    # mm_hash of request_low exists in cache after scheduling for all scenarios
    _assert_right_encoder_cache_allocated(scheduler, requests=[request_low])

    # Should only call update_state_after_alloc when loaded externally
    if cache_exist == "connector_only":
        scheduler.ec_connector.update_state_after_alloc.assert_called_with(
            request_low, 0
        )
        _assert_right_ec_connector_metadata(
            output, mm_features_list=request_low.mm_features
        )
    else:
        scheduler.ec_connector.update_state_after_alloc.assert_not_called()
        # ECConnector should carry no metadata
        _assert_right_ec_connector_metadata(output, mm_features_list=[])

    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # Should only schedule encoder inputs when the cache is not found anywhere
    if cache_exist == "no_where":
        _assert_right_encoder_inputs(
            output,
            requests=[request_low],
            expected_encoder_inputs=[[0]],
            expected_total_reqs=1,
        )
    else:
        _assert_right_encoder_inputs(output, expected_total_reqs=0)


@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector):
    """
    Scenario:
    - Encoder cache size: 32
    - Request A: 1 feature (12 tokens) → NOT cached remotely.
    - Request B: 3 features (3 x 10 tokens) → ALL cached remotely.

    Steps:
    1. Schedule Request A (locally uses 12 tokens).
    2. Schedule Request B (remote cache) - only schedule the 1st and 2nd items.
    3. Free A's cache, then schedule B again (continuation) - schedule the 3rd image.
    """
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        max_num_batched_tokens=1024,
        enable_prefix_caching=True,
        use_kv_connector=use_kv_connector,
        block_size=16,
        num_blocks=11,  # Can hold 160 tokens (first block is null)
        use_ec_connector=True,
        ec_role="ec_consumer",
    )

    # Limit the number of available slots of the EncoderCacheManager
    scheduler.encoder_cache_manager = EncoderCacheManager(cache_size=32)
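    # Budget sketch (for reference): request A's single 12-token item plus
    # request B's first two 10-token items use 12 + 10 + 10 = 32 tokens,
    # exactly filling the cache, so B's third item cannot be allocated until
    # A's entry is freed.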

    # Create MM request1
    NUM_TOKENS_1 = 50  # NOTE: includes mm tokens
    NUM_ENCODER_TOKENS_1 = 12
    mm_hashes_list_1 = [["hash1_1"]]
    mm_positions_1 = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1)]]

    request1 = create_requests(
        num_requests=1,
        num_tokens=NUM_TOKENS_1,
        mm_hashes_list=mm_hashes_list_1,
        mm_positions=mm_positions_1,
        max_tokens=1,  # For simplicity
        req_ids=["req1"],
    )[0]

    # Create MM request2 with 3 MM items
    NUM_TOKENS_2 = 40
    NUM_ENCODER_TOKENS_2 = 10
    mm_hashes_list_2 = [["hash2_1", "hash2_2", "hash2_3"]]
    mm_positions_2 = [
        [
            PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
            PlaceholderRange(offset=12, length=NUM_ENCODER_TOKENS_2),
            PlaceholderRange(offset=24, length=NUM_ENCODER_TOKENS_2),
        ]
    ]

    request2 = create_requests(
        num_requests=1,
        num_tokens=NUM_TOKENS_2,
        mm_hashes_list=mm_hashes_list_2,
        mm_positions=mm_positions_2,
        max_tokens=10,
        req_ids=["req2"],
    )[0]

    # Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely
    scheduler.ec_connector.has_caches = Mock(
        side_effect=lambda req: [True, True, True] if req == request2 else [False]
    )
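    # NOTE: side_effect (rather than return_value) lets the mock answer per
    # request: has_caches() reports one boolean per mm feature, so request2's
    # three features get [True, True, True] while request1's single feature
    # gets [False].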
    scheduler.ec_connector.update_state_after_alloc = Mock(
        wraps=scheduler.ec_connector.update_state_after_alloc
    )

    scheduler.add_request(request1)
    scheduler.add_request(request2)
    output = scheduler.schedule()

    # Now, since the encoder cache manager can only store 32 tokens,
    # it should allocate mm items hash1_1, hash2_1 and hash2_2
    scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
    assert scheduled_tokens == NUM_TOKENS_1
    assert scheduler.get_num_unfinished_requests() == 2

    # Encoder cache should contain hash1_1 from request1 plus
    # hash2_1 and hash2_2 from request2
    _assert_right_encoder_cache_allocated(
        scheduler, hashes_to_check=["hash1_1", "hash2_1", "hash2_2"]
    )

    # request2's 2nd mm item is the last call of update_state_after_alloc
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 1)
    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # ECConnector should carry metadata of hash2_1 and hash2_2 ONLY
    _assert_right_ec_connector_metadata(
        output, mm_features_list=[request2.mm_features[0], request2.mm_features[1]]
    )

    # Should schedule ONLY 1 encoder input
    _assert_right_encoder_inputs(
        output,
        requests=[request1],
        expected_encoder_inputs=[[0]],  # index 0 of the mm item of request1
        expected_total_reqs=1,
    )

    # Simulate model execution for 1 step
    model_output = ModelRunnerOutput(
        req_ids=[request1.request_id, request2.request_id],
        req_id_to_index={request1.request_id: 0, request2.request_id: 1},
        sampled_token_ids=[[100], [121]],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # request1 is finished after outputting 1 token
    # Finish request
    scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)
    assert scheduler.get_num_unfinished_requests() == 1

    # Schedule again; now request1's encoder cache should be freed
    # -> hash2_3 can be scheduled and allocated
    output = scheduler.schedule()

    # Check
    # request2 should continue to be scheduled
    scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
    assert scheduled_tokens > 0

    # Encoder cache should contain all mm items from request2
    _assert_right_encoder_cache_allocated(scheduler, requests=[request2])

    # request2's 3rd mm item is the ONLY call of update_state_after_alloc
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 2)
    scheduler.ec_connector.update_state_after_alloc.assert_called_once()

    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # ECConnector should carry metadata for hash2_3 ONLY
    _assert_right_ec_connector_metadata(
        output, mm_features_list=[request2.mm_features[2]]
    )

    # Should schedule no encoder input
    _assert_right_encoder_inputs(
        output,
        expected_total_reqs=0,
    )


# ==============================================================================
# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests end
# ==============================================================================