[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import List
|
||||
from typing import Deque, List, Set, Tuple
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # noqa
|
||||
@@ -65,7 +65,7 @@ def test_scheduler_abort_seq_group():
|
||||
|
||||
# Add multiple seq groups to scheduler.
|
||||
num_seq_group = 4
|
||||
request_ids = set()
|
||||
request_ids: Set[str] = set()
|
||||
for i in range(num_seq_group):
|
||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
@@ -347,7 +347,7 @@ def test_prefill_schedule_max_prompt_len():
|
||||
Test prompt longer than max_prompt_len is aborted.
|
||||
"""
|
||||
scheduler = initialize_scheduler(max_model_len=30)
|
||||
_, seq_group = create_dummy_prompt(0, prompt_length=60)
|
||||
_, seq_group = create_dummy_prompt("0", prompt_length=60)
|
||||
waiting = deque([seq_group])
|
||||
budget = create_token_budget()
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
@@ -364,7 +364,7 @@ def test_prefill_schedule_token_budget():
|
||||
Test token budget respected.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(token_budget=0)
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
@@ -419,7 +419,7 @@ def test_prefill_schedule_max_seqs():
|
||||
Test max seq respected.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
@@ -453,9 +453,9 @@ def test_prefill_schedule_max_lora():
|
||||
"""
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config)
|
||||
waiting = deque()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(token_budget=120)
|
||||
curr_loras = set()
|
||||
curr_loras: Set[int] = set()
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i),
|
||||
prompt_length=60,
|
||||
@@ -499,7 +499,7 @@ def test_prefill_schedule_no_block_manager_capacity():
|
||||
Test that a sequence cannot be scheduled because the block manager has no capacity.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
@@ -536,7 +536,7 @@ def test_decode_schedule_preempted():
|
||||
Test decodes cannot be scheduled and preempted.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
running = deque()
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
for i in range(3):
|
||||
@@ -577,7 +577,7 @@ def test_decode_swap_beam_search():
|
||||
Test that a sequence group with best_of > 1 swaps out its blocks.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
running = deque()
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
budget = create_token_budget()
|
||||
@@ -628,7 +628,7 @@ def test_schedule_decode_blocks_to_copy_update():
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
running = deque()
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -656,10 +656,10 @@ def test_schedule_decode_blocks_to_copy_update():
|
||||
|
||||
def test_schedule_swapped_simple():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = []
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
@@ -683,10 +683,10 @@ def test_schedule_swapped_simple():
|
||||
|
||||
def test_schedule_swapped_max_token_budget():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = []
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -717,10 +717,10 @@ def test_schedule_swapped_max_token_budget():
|
||||
|
||||
def test_schedule_swapped_max_seqs():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = []
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for i in range(4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -750,10 +750,10 @@ def test_schedule_swapped_max_seqs():
|
||||
def test_schedule_swapped_max_loras():
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config)
|
||||
swapped = deque()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = set()
|
||||
blocks_to_swap_out = []
|
||||
curr_loras: Set[int] = set()
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i),
|
||||
prompt_length=60,
|
||||
@@ -779,10 +779,10 @@ def test_schedule_swapped_max_loras():
|
||||
|
||||
def test_schedule_swapped_cannot_swap_in():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = []
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -806,10 +806,10 @@ def test_schedule_swapped_cannot_swap_in():
|
||||
|
||||
def test_infeasible_swap():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = []
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -834,13 +834,13 @@ def test_infeasible_swap():
|
||||
|
||||
def test_schedule_swapped_blocks_to_copy():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
blocks_to_swap_out = []
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user