[mypy] Enable type checking for test directory (#5017)

This commit is contained in:
Cyrus Leung
2024-06-15 12:45:31 +08:00
committed by GitHub
parent 1b8a0d71cf
commit 0e9164b40a
92 changed files with 509 additions and 378 deletions

View File

@@ -1,6 +1,6 @@
import time
from collections import deque
from typing import List
from typing import Deque, List, Set, Tuple
from unittest.mock import MagicMock
import pytest # noqa
@@ -65,7 +65,7 @@ def test_scheduler_abort_seq_group():
# Add multiple seq groups to scheduler.
num_seq_group = 4
request_ids = set()
request_ids: Set[str] = set()
for i in range(num_seq_group):
_, seq_group = create_dummy_prompt(str(i), block_size)
scheduler.add_seq_group(seq_group)
@@ -347,7 +347,7 @@ def test_prefill_schedule_max_prompt_len():
Test prompt longer than max_prompt_len is aborted.
"""
scheduler = initialize_scheduler(max_model_len=30)
_, seq_group = create_dummy_prompt(0, prompt_length=60)
_, seq_group = create_dummy_prompt("0", prompt_length=60)
waiting = deque([seq_group])
budget = create_token_budget()
remaining_waiting, output = scheduler._schedule_prefills(
@@ -364,7 +364,7 @@ def test_prefill_schedule_token_budget():
Test token budget respected.
"""
scheduler = initialize_scheduler()
waiting = deque()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(token_budget=0)
for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
@@ -419,7 +419,7 @@ def test_prefill_schedule_max_seqs():
Test max seq respected.
"""
scheduler = initialize_scheduler()
waiting = deque()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(max_num_seqs=2)
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
@@ -453,9 +453,9 @@ def test_prefill_schedule_max_lora():
"""
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config)
waiting = deque()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(token_budget=120)
curr_loras = set()
curr_loras: Set[int] = set()
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
@@ -499,7 +499,7 @@ def test_prefill_schedule_no_block_manager_capacity():
Test sequence cannot be scheduled because the block manager has no capacity.
"""
scheduler = initialize_scheduler()
waiting = deque()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget()
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
@@ -536,7 +536,7 @@ def test_decode_schedule_preempted():
Test decodes cannot be scheduled and preempted.
"""
scheduler = initialize_scheduler()
running = deque()
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
for i in range(3):
@@ -577,7 +577,7 @@ def test_decode_swap_beam_search():
Test best_of > 1 swap out blocks
"""
scheduler = initialize_scheduler()
running = deque()
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
budget = create_token_budget()
@@ -628,7 +628,7 @@ def test_schedule_decode_blocks_to_copy_update():
"""
scheduler = initialize_scheduler()
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
running = deque()
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
scheduler._allocate_and_set_running(seq_group)
@@ -656,10 +656,10 @@ def test_schedule_decode_blocks_to_copy_update():
def test_schedule_swapped_simple():
scheduler = initialize_scheduler()
swapped = deque()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = []
blocks_to_swap_out: List[Tuple[int, int]] = []
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
@@ -683,10 +683,10 @@ def test_schedule_swapped_simple():
def test_schedule_swapped_max_token_budget():
scheduler = initialize_scheduler()
swapped = deque()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = []
blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@@ -717,10 +717,10 @@ def test_schedule_swapped_max_token_budget():
def test_schedule_swapped_max_seqs():
scheduler = initialize_scheduler()
swapped = deque()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = []
blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
scheduler._allocate_and_set_running(seq_group)
@@ -750,10 +750,10 @@ def test_schedule_swapped_max_seqs():
def test_schedule_swapped_max_loras():
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config)
swapped = deque()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = set()
blocks_to_swap_out = []
curr_loras: Set[int] = set()
blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
@@ -779,10 +779,10 @@ def test_schedule_swapped_max_loras():
def test_schedule_swapped_cannot_swap_in():
scheduler = initialize_scheduler()
swapped = deque()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = []
blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@@ -806,10 +806,10 @@ def test_schedule_swapped_cannot_swap_in():
def test_infeasible_swap():
scheduler = initialize_scheduler()
swapped = deque()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = []
blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@@ -834,13 +834,13 @@ def test_infeasible_swap():
def test_schedule_swapped_blocks_to_copy():
scheduler = initialize_scheduler()
swapped = deque()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
blocks_to_swap_out = []
blocks_to_swap_out: List[Tuple[int, int]] = []
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)