[Core] feat: Implement Priority Scheduling in V1 Engine (#19057)
Signed-off-by: amit <amit.man@gmail.com> Co-authored-by: Roger Wang <Rogerw0108@gmail.com>
This commit is contained in:
@@ -1150,7 +1150,6 @@ def test_kv_connector_handles_preemption():
|
||||
assert len(scheduler.running) == 1
|
||||
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 1
|
||||
# All memory should be freed since nothing is running.
|
||||
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
|
||||
== NUM_BLOCKS - 1
|
||||
@@ -1265,3 +1264,592 @@ def test_memory_leak():
|
||||
|
||||
# Confirm no memory leak.
|
||||
assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def create_scheduler_with_priority(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
) -> Scheduler:
    '''Create scheduler with priority policy enabled.

    Args:
      model: model under test
      max_num_seqs: max sequences to schedule
      max_num_batched_tokens: max num tokens to batch
      enable_prefix_caching: optionally force APC config
                             (True/False) or use default
                             (None)
      long_prefill_token_threshold: scheduler threshold for chunking
                                    long prefills (0 keeps the default)
      disable_chunked_mm_input: disable chunking of multi-modal inputs
      use_kv_connector: attach a SharedStorageConnector KVTransferConfig
      num_blocks: number of blocks in the KV-cache block pool
      block_size: tokens per KV-cache block
      max_model_len: model context length; defaults to
                     max_num_batched_tokens when None
      num_speculative_tokens: when set, enables ngram speculative decoding

    Returns:
      {class}`Scheduler` instance with priority scheduling
    '''
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        policy="priority",  # Enable priority scheduling
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    # Cache config, optionally force APC
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    # Optional KV-transfer config for connector-based tests.
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ) if use_kv_connector else None

    speculative_config: Optional[SpeculativeConfig] = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens)

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )
|
||||
|
||||
|
||||
def create_requests_with_priority(
        num_requests: int,
        priorities: list[int],
        arrival_times: Optional[list[float]] = None,
        num_tokens: int = 10,
        mm_positions: Optional[list[PlaceholderRange]] = None,
        max_tokens: int = 16,
        stop_token_ids: Optional[list[int]] = None,
        prompt_logprobs: Optional[int] = None):
    """Build `num_requests` Requests carrying the given priorities.

    Arrival times default to 0.0, 1.0, ... when not supplied; request ids
    are the stringified indices "0", "1", ...
    """
    assert len(priorities) == num_requests
    if arrival_times is None:
        arrival_times = [float(i) for i in range(num_requests)]
    else:
        assert len(arrival_times) == num_requests

    # One shared SamplingParams is sufficient for every request.
    params = SamplingParams(ignore_eos=False,
                            max_tokens=max_tokens,
                            stop_token_ids=stop_token_ids,
                            prompt_logprobs=prompt_logprobs)

    built: list[Request] = []
    for idx in range(num_requests):
        if mm_positions is not None:
            placeholder = mm_positions[idx]
            mm_kwargs = [MultiModalKwargs({})] * len(placeholder)
        else:
            placeholder = None
            mm_kwargs = None
        built.append(
            Request(
                request_id=f"{idx}",
                prompt_token_ids=[idx] * num_tokens,
                sampling_params=params,
                pooling_params=None,
                multi_modal_inputs=mm_kwargs,
                multi_modal_placeholders=placeholder,
                multi_modal_hashes=None,
                eos_token_id=EOS_TOKEN_ID,
                arrival_time=arrival_times[idx],
                priority=priorities[idx],
            ))
    return built
|
||||
|
||||
|
||||
def test_priority_scheduling_basic_ordering():
    """Requests must be scheduled lowest-priority-value first."""
    scheduler = create_scheduler_with_priority()

    # Enqueue in deliberately shuffled priority order (0 = highest).
    reqs = create_requests_with_priority(num_requests=3,
                                         priorities=[2, 0, 1],
                                         arrival_times=[1.0, 2.0, 3.0])
    for req in reqs:
        scheduler.add_request(req)

    out = scheduler.schedule()

    # The budget fits everything, so all three go out at once...
    assert len(out.scheduled_new_reqs) == 3
    # ...ordered by priority: req_1 (0), req_2 (1), req_0 (2).
    assert [r.req_id for r in out.scheduled_new_reqs] == ["1", "2", "0"]
|
||||
|
||||
|
||||
def test_priority_scheduling_arrival_time_tiebreaker():
    """Equal priorities fall back to arrival-time (FCFS) ordering."""
    scheduler = create_scheduler_with_priority()

    # Same priority everywhere; only arrival times differ.
    reqs = create_requests_with_priority(num_requests=3,
                                         priorities=[1, 1, 1],
                                         arrival_times=[3.0, 1.0, 2.0])
    for req in reqs:
        scheduler.add_request(req)

    out = scheduler.schedule()

    # Everything fits in the budget.
    assert len(out.scheduled_new_reqs) == 3
    # Earliest arrival first: req_1 (1.0), req_2 (2.0), req_0 (3.0).
    assert [r.req_id for r in out.scheduled_new_reqs] == ["1", "2", "0"]
|
||||
|
||||
|
||||
def test_priority_scheduling_mixed_priority_and_arrival():
    """Priority dominates; arrival time breaks ties within a priority."""
    scheduler = create_scheduler_with_priority()

    reqs = create_requests_with_priority(
        num_requests=4,
        priorities=[2, 1, 1, 0],
        arrival_times=[1.0, 3.0, 2.0, 4.0])
    for req in reqs:
        scheduler.add_request(req)

    out = scheduler.schedule()

    # All four fit in the budget.
    assert len(out.scheduled_new_reqs) == 4

    # Expected order:
    #   req_3 (prio 0), then req_2 (prio 1, arrival 2.0) ahead of
    #   req_1 (prio 1, arrival 3.0), then req_0 (prio 2).
    assert [r.req_id for r in out.scheduled_new_reqs] == ["3", "2", "1", "0"]
|
||||
|
||||
|
||||
def test_priority_scheduling_preemption():
    """Lower-priority running requests should be preempted when a
    higher-priority request arrives under memory pressure."""
    # Tiny block pool so a new allocation can force preemption.
    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,
        max_num_batched_tokens=200,
        num_blocks=6,
        block_size=16,
    )

    # Two low-priority requests that consume most of the pool.
    low_prio = create_requests_with_priority(
        num_requests=2,
        priorities=[5, 5],
        arrival_times=[1.0, 2.0],
        num_tokens=30)
    for req in low_prio:
        scheduler.add_request(req)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 2

    # Fake one model step so both requests land in the running set.
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in low_prio],
        req_id_to_index={
            req.request_id: i
            for i, req in enumerate(low_prio)
        },
        sampled_token_ids=[[100] for _ in low_prio],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)
    assert len(scheduler.running) == 2

    # A high-priority newcomer whose allocation should trigger preemption.
    high_prio = create_requests_with_priority(
        num_requests=1,
        priorities=[0],
        arrival_times=[3.0],
        num_tokens=30)[0]
    scheduler.add_request(high_prio)

    output = scheduler.schedule()

    # When preemption fires during running-request scheduling, waiting
    # requests are deferred to the next step — so inspect the waiting
    # queue to tell the two valid outcomes apart.
    if len(scheduler.waiting) > 1:
        # Preempted: high-priority request is scheduled on the next step.
        next_output = scheduler.schedule()
        assert len(next_output.scheduled_new_reqs) == 1
        assert next_output.scheduled_new_reqs[0].req_id == "0"
    else:
        # Memory sufficed; the high-priority request went out immediately.
        assert len(output.scheduled_new_reqs) == 1
        assert output.scheduled_new_reqs[0].req_id == "0"
|
||||
|
||||
|
||||
def test_priority_scheduling_no_preemption_when_space_available():
    """A high-priority arrival must not preempt anyone while free
    capacity remains."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,
        max_num_batched_tokens=200,
    )

    # Two low-priority requests, moved into the running set.
    low_prio = create_requests_with_priority(
        num_requests=2,
        priorities=[5, 5],
        arrival_times=[1.0, 2.0],
        num_tokens=30)
    for req in low_prio:
        scheduler.add_request(req)

    output = scheduler.schedule()
    step_result = ModelRunnerOutput(
        req_ids=[req.request_id for req in low_prio],
        req_id_to_index={
            req.request_id: i
            for i, req in enumerate(low_prio)
        },
        sampled_token_ids=[[100] for _ in low_prio],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, step_result)

    # Add the high-priority request; a free slot exists for it.
    high_prio = create_requests_with_priority(num_requests=1,
                                              priorities=[0],
                                              arrival_times=[3.0],
                                              num_tokens=30)[0]
    scheduler.add_request(high_prio)

    output = scheduler.schedule()

    # Scheduled without evicting anyone.
    assert len(output.scheduled_new_reqs) == 1
    assert len(scheduler.running) == 3
    assert len(scheduler.waiting) == 0
|
||||
|
||||
|
||||
def test_priority_scheduling_preemption_victim_selection():
    """The waiting queue must rank potential preemption victims by
    priority, then arrival time."""
    # One slot forces everything else to queue, exposing the ordering.
    scheduler = create_scheduler_with_priority(max_num_seqs=1)

    reqs = create_requests_with_priority(
        num_requests=3,
        priorities=[3, 2, 0],  # low, medium, high
        arrival_times=[1.0, 2.0, 3.0],
        num_tokens=10)
    for req in reqs:
        scheduler.add_request(req)

    # Only the highest-priority request (req_2, prio 0) runs.
    out = scheduler.schedule()
    assert len(out.scheduled_new_reqs) == 1
    assert out.scheduled_new_reqs[0].req_id == "2"

    # Remaining two wait in priority order.
    assert len(scheduler.waiting) == 2
    queued = list(scheduler.waiting)
    assert [req.priority for req in queued] == [2, 3]
    assert [req.request_id for req in queued] == ["1", "0"]
|
||||
|
||||
|
||||
def test_priority_scheduling_equal_priority_preemption():
    """With equal priorities, arrival time decides queue order."""
    scheduler = create_scheduler_with_priority(max_num_seqs=1)

    reqs = create_requests_with_priority(
        num_requests=3,
        priorities=[2, 2, 2],
        arrival_times=[3.0, 1.0, 2.0],
        num_tokens=10)
    for req in reqs:
        scheduler.add_request(req)

    # Earliest arrival (req_1, 1.0) is scheduled first.
    out = scheduler.schedule()
    assert len(out.scheduled_new_reqs) == 1
    assert out.scheduled_new_reqs[0].req_id == "1"

    # The rest wait in arrival-time order: req_2 (2.0), req_0 (3.0).
    assert len(scheduler.waiting) == 2
    queued = list(scheduler.waiting)
    assert [req.arrival_time for req in queued] == [2.0, 3.0]
    assert [req.request_id for req in queued] == ["2", "0"]
|
||||
|
||||
|
||||
def test_priority_scheduling_waiting_queue_order():
    """Waiting queue must hold deferred requests in priority order."""
    # Single slot: only one request runs, the rest queue up.
    scheduler = create_scheduler_with_priority(max_num_seqs=1)

    reqs = create_requests_with_priority(
        num_requests=4,
        priorities=[3, 1, 2, 0],
        arrival_times=[1.0, 2.0, 3.0, 4.0],
        num_tokens=10)
    for req in reqs:
        scheduler.add_request(req)

    # Highest priority (req_3, prio 0) is the lone scheduled request.
    out = scheduler.schedule()
    assert len(out.scheduled_new_reqs) == 1
    assert out.scheduled_new_reqs[0].req_id == "3"

    # The other three queue by ascending priority value.
    assert len(scheduler.waiting) == 3
    queued = list(scheduler.waiting)
    assert [req.request_id for req in queued] == ["1", "2", "0"]
    assert [req.priority for req in queued] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_priority_scheduling_fcfs_fallback():
    """Uniform priorities must degrade to plain FCFS scheduling."""
    scheduler = create_scheduler_with_priority()

    reqs = create_requests_with_priority(
        num_requests=4,
        priorities=[1, 1, 1, 1],
        arrival_times=[4.0, 1.0, 3.0, 2.0])
    for req in reqs:
        scheduler.add_request(req)

    out = scheduler.schedule()

    # Everything fits; order is purely by arrival time:
    # req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0).
    assert len(out.scheduled_new_reqs) == 4
    assert [r.req_id for r in out.scheduled_new_reqs] == ["1", "3", "2", "0"]
|
||||
|
||||
|
||||
def test_priority_scheduling_with_limited_slots():
    """When max_num_seqs caps concurrency, the top-priority requests
    take the slots and the rest queue in priority order."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=2,  # Two slots only
        max_num_batched_tokens=1000,  # Token budget is not the limiter
    )

    reqs = create_requests_with_priority(
        num_requests=4,
        priorities=[3, 1, 2, 0],
        arrival_times=[1.0, 2.0, 3.0, 4.0],
        num_tokens=10)
    for req in reqs:
        scheduler.add_request(req)

    out = scheduler.schedule()
    assert len(out.scheduled_new_reqs) == 2

    # The two best priorities win the slots: req_3 (0) and req_1 (1).
    taken = [r.req_id for r in out.scheduled_new_reqs]
    assert "3" in taken
    assert "1" in taken

    # Remaining two queue as req_2 (prio 2) then req_0 (prio 3).
    assert len(scheduler.waiting) == 2
    queued = list(scheduler.waiting)
    assert [req.priority for req in queued] == [2, 3]
    assert [req.request_id for req in queued] == ["2", "0"]
|
||||
|
||||
|
||||
def test_priority_scheduling_heap_property():
    """Draining the waiting queue one request at a time must always
    yield the remaining minimum priority (heap property)."""
    scheduler = create_scheduler_with_priority(max_num_seqs=1)

    # Shuffled priorities; arrival times are just 0.0, 1.0, ...
    priorities = [5, 1, 8, 3, 2, 7, 4, 6]
    reqs = create_requests_with_priority(
        num_requests=len(priorities),
        priorities=priorities,
        arrival_times=[float(i) for i in range(len(priorities))],
        num_tokens=10)
    for req in reqs:
        scheduler.add_request(req)

    observed = []
    while scheduler.waiting:
        out = scheduler.schedule()
        if out.scheduled_new_reqs:
            scheduled = out.scheduled_new_reqs[0]
            observed.append(reqs[int(scheduled.req_id)].priority)

            # Complete the step, then finish the request so the single
            # slot frees up for the next schedule() call.
            step_result = ModelRunnerOutput(
                req_ids=[scheduled.req_id],
                req_id_to_index={scheduled.req_id: 0},
                sampled_token_ids=[[100]],
                spec_token_ids=None,
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[],
            )
            scheduler.update_from_output(out, step_result)
            scheduler.finish_requests(scheduled.req_id,
                                      RequestStatus.FINISHED_STOPPED)

    # Requests emerge lowest priority value first.
    assert observed == sorted(priorities)
|
||||
|
||||
Reference in New Issue
Block a user