[core][scheduler] simplify and improve scheduler (#6867)
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Deque, List, Set, Tuple
|
||||
from typing import List, Set, Tuple
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # noqa
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.core.policy import PolicyFactory
|
||||
from vllm.core.scheduler import Scheduler, SchedulingBudget
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
|
||||
@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
|
||||
"""
|
||||
scheduler = initialize_scheduler(max_model_len=30)
|
||||
_, seq_group = create_dummy_prompt("0", prompt_length=60)
|
||||
waiting = deque([seq_group])
|
||||
scheduler.add_seq_group(seq_group)
|
||||
budget = create_token_budget()
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 1
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
|
||||
Test token budget respected.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(token_budget=0)
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
|
||||
# 0 token budget == nothing is scheduled.
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
|
||||
|
||||
# 60 token budget == 1 request scheduled.
|
||||
budget = create_token_budget(token_budget=60)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 1
|
||||
assert budget.num_batched_tokens == 60
|
||||
@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
|
||||
|
||||
# Test that current_batched_tokens is respected.
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget(token_budget=60)
|
||||
add_token_budget(budget, 30, 0)
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
# Cannot schedule a prompt that doesn't fit the budget.
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 30
|
||||
@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
|
||||
assert len(remaining_waiting) == 1
|
||||
budget = create_token_budget(token_budget=90)
|
||||
add_token_budget(budget, 30, 0)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.seq_groups) == 1
|
||||
assert budget.num_batched_tokens == 90
|
||||
assert budget.num_curr_seqs == 1
|
||||
@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
|
||||
Test max seq respected.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 2
|
||||
assert budget.num_batched_tokens == 120
|
||||
@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
|
||||
assert len(remaining_waiting) == 1
|
||||
|
||||
# Verify curr_num_seqs respected.
|
||||
waiting = deque()
|
||||
scheduler.waiting = deque()
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
add_token_budget(budget, 0, 2)
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
|
||||
"""
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config)
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(token_budget=120)
|
||||
curr_loras: Set[int] = set()
|
||||
for i in range(2):
|
||||
@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
|
||||
lora_name=str(i),
|
||||
lora_int_id=i + 1,
|
||||
lora_path="abc"))
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
# Add two more requests to verify lora is prioritized.
|
||||
# 0: Lora, 1: Lora, 2: regular, 3: regular
|
||||
# In the first iteration, requests 0 and 2 are scheduled.
|
||||
@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
|
||||
# prioritized. Verify that.
|
||||
for i in range(2, 4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
# Schedule 2 requests (0 and 2)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, curr_loras)
|
||||
output = scheduler._schedule_prefills(budget, curr_loras)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 2
|
||||
assert budget.num_batched_tokens == 120
|
||||
@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
|
||||
# Reset curr_loras so that it can be scheduled.
|
||||
curr_loras = set()
|
||||
budget = create_token_budget(token_budget=60)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
remaining_waiting, budget, curr_loras)
|
||||
output = scheduler._schedule_prefills(budget, curr_loras)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.seq_groups) == 1
|
||||
assert output.seq_groups[0].seq_group.request_id == "1"
|
||||
assert len(remaining_waiting) == 1
|
||||
@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
|
||||
Test that sequences cannot be scheduled when the block manager has no capacity.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
scheduler.block_manager.can_allocate = MagicMock()
|
||||
scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
|
||||
remainig_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(remainig_waiting) == 3
|
||||
assert len(remaining_waiting) == 3
|
||||
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
scheduler.block_manager.can_allocate = MagicMock()
|
||||
scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 3
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
@@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
|
||||
Test that decodes which cannot be scheduled are preempted.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
running.append(seq_group)
|
||||
scheduler._add_seq_group_to_running(seq_group)
|
||||
scheduler.block_manager.can_append_slots = MagicMock()
|
||||
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
@@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
|
||||
# 1 cannot be scheduled, and the lowest priority (request 2)
|
||||
# should be preempted. 1 will also be preempted.
|
||||
budget = create_token_budget()
|
||||
remainig_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_running(budget, curr_loras)
|
||||
remainig_running = scheduler.running
|
||||
assert len(remainig_running) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
@@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
|
||||
Test best_of > 1 swap out blocks
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
running.append(seq_group)
|
||||
scheduler._add_seq_group_to_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
budget.add_num_seqs(seq_group.request_id,
|
||||
seq_group.get_max_num_running_seqs())
|
||||
@@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
|
||||
expected_swap_mapping = [("5", "7")]
|
||||
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
|
||||
|
||||
remainig_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_running(budget, curr_loras)
|
||||
remainig_running = scheduler.running
|
||||
assert len(remainig_running) == 0
|
||||
assert len(output.decode_seq_groups) == 2
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
running.append(seq_group)
|
||||
scheduler._add_seq_group_to_running(seq_group)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.append_slots = MagicMock()
|
||||
scheduler.block_manager.append_slots.return_value = [(2, 3)]
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_running(budget, curr_loras)
|
||||
remaining_running = scheduler.running
|
||||
assert len(remaining_running) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
|
||||
|
||||
def test_schedule_swapped_simple():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 0
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 2
|
||||
@@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
|
||||
|
||||
def test_schedule_swapped_max_token_budget():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
budget = create_token_budget(token_budget=1)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 2
|
||||
@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
|
||||
# Verify num_batched_tokens are respected.
|
||||
budget = create_token_budget(token_budget=1)
|
||||
add_token_budget(budget, 1, 0)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
remaining_swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 0
|
||||
@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
|
||||
|
||||
def test_schedule_swapped_max_seqs():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for i in range(4):
|
||||
@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 2
|
||||
assert budget.num_curr_seqs == 2
|
||||
@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
# Verify num_curr_seqs are respected.
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
remaining_swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 2
|
||||
assert budget.num_curr_seqs == 2
|
||||
@@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs():
|
||||
def test_schedule_swapped_max_loras():
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config)
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras: Set[int] = set()
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for i in range(2):
|
||||
@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 1
|
||||
@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
|
||||
|
||||
def test_schedule_swapped_cannot_swap_in():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.can_swap_in = MagicMock()
|
||||
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
|
||||
# Since we cannot swap in, none of the requests are swapped in.
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
@@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in():
|
||||
|
||||
def test_infeasible_swap():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
@@ -815,15 +790,15 @@ def test_infeasible_swap():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.can_swap_in = MagicMock()
|
||||
scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
|
||||
# Since we cannot swap in, none of the requests are swapped in.
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 0
|
||||
assert len(output.infeasible_seq_groups) == 2
|
||||
assert budget.num_batched_tokens == 0
|
||||
@@ -834,23 +809,21 @@ def test_infeasible_swap():
|
||||
|
||||
def test_schedule_swapped_blocks_to_copy():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.append_slots = MagicMock()
|
||||
scheduler.block_manager.append_slots.return_value = [(2, 3)]
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
Reference in New Issue
Block a user