Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Sequence
|
||||
from itertools import cycle
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm):
|
||||
|
||||
def get_output_from_llm_generator(
|
||||
llm_generator, prompts,
|
||||
sampling_params) -> Tuple[List[str], List[List[int]], float]:
|
||||
tokens: List[str] = []
|
||||
token_ids: List[List[int]] = []
|
||||
sampling_params) -> tuple[list[str], list[list[int]], float]:
|
||||
tokens: list[str] = []
|
||||
token_ids: list[list[int]] = []
|
||||
acceptance_rate: float = -1.0
|
||||
for llm in llm_generator():
|
||||
maybe_assert_ngram_worker(llm)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -42,7 +40,7 @@ def test_get_token_ids_to_score(k: int):
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
expected_output: List[List[int]] = [
|
||||
expected_output: list[list[int]] = [
|
||||
[],
|
||||
]
|
||||
for i in range(proposal_token_ids.shape[0]):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -221,7 +220,7 @@ def test_same_output_for_multi_step():
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: List[SamplerOutput] = []
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
continuations = [[1] for _ in prompts]
|
||||
set_random_seed(seed)
|
||||
|
||||
@@ -243,15 +242,15 @@ def test_same_output_for_multi_step():
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Get token ids and logprobs for comparison.
|
||||
multi_step_output_logprobs: List[List[Dict[int,
|
||||
multi_step_output_logprobs: list[list[dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
single_step_output_logprobs: List[List[Dict[int,
|
||||
single_step_output_logprobs: list[list[dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
|
||||
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
|
||||
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
|
||||
multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
|
||||
single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
|
||||
for i, _ in enumerate(prompts):
|
||||
for multi_step, single_step in zip(multi_step_output,
|
||||
single_step_output):
|
||||
@@ -336,7 +335,7 @@ def test_multi_step_with_batch_expansion_correct_output():
|
||||
# will simulate the bonus token case with the second token
|
||||
# being the bonus token.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: List[SamplerOutput] = []
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
set_random_seed(seed)
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
@@ -430,7 +429,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
|
||||
# will simulate the bonus token case with the second token
|
||||
# being the bonus token.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: List[SamplerOutput] = []
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
set_random_seed(seed)
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -15,7 +14,7 @@ from vllm.worker.worker import Worker
|
||||
from .utils import create_batch, create_worker
|
||||
|
||||
|
||||
def create_proposal(propose_lens: List[int], vocab_size: int,
|
||||
def create_proposal(propose_lens: list[int], vocab_size: int,
|
||||
device: str) -> SpeculativeProposals:
|
||||
batch_size = len(propose_lens)
|
||||
max_propose_len = max(propose_lens)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Set
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -123,7 +122,7 @@ def test_batch_expansion_correctly_calls_target_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
seen_contexts: List[List[int]] = []
|
||||
seen_contexts: list[list[int]] = []
|
||||
|
||||
call_args_list = target_worker.execute_model.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
@@ -136,7 +135,7 @@ def test_batch_expansion_correctly_calls_target_model(
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seen_contexts.append(seq_data.get_token_ids())
|
||||
|
||||
expected_seen_contexts: List[List[int]] = []
|
||||
expected_seen_contexts: list[list[int]] = []
|
||||
|
||||
for prompt, prev_generated, draft_tokens in zip(
|
||||
prompts, prev_output_tokens, proposal_token_ids.tolist()):
|
||||
@@ -338,11 +337,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
|
||||
next(iter(seq_group_metadata.seq_data.keys()))
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
]
|
||||
actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
|
||||
actual_output_by_seq: dict[int, list[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
|
||||
expected_output_by_seq: dict[int, list[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
@@ -728,7 +727,7 @@ def test_populate_seq_ids_with_bonus_tokens():
|
||||
size=(batch_size, (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
|
||||
expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set)
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for seq_id in seq_group_metadata.seq_data:
|
||||
expected_request_id_seq_ids_mapping[
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from itertools import count
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import TypeVar, Union
|
||||
from typing import Callable, Optional, TypeVar, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
@@ -44,7 +43,7 @@ def mock_worker(cls=None,
|
||||
return worker
|
||||
|
||||
|
||||
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
|
||||
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
|
||||
seed_iter = iter(rand_seeds)
|
||||
original_execute_model = worker.execute_model
|
||||
|
||||
@@ -56,7 +55,7 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
|
||||
return new_execute_model
|
||||
|
||||
|
||||
def zero_kv_cache(cache_engine: List[CacheEngine]):
|
||||
def zero_kv_cache(cache_engine: list[CacheEngine]):
|
||||
assert cache_engine[0].gpu_cache
|
||||
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
|
||||
key_blocks.zero_()
|
||||
@@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T],
|
||||
|
||||
|
||||
def create_seq_group_metadata_from_prompts(
|
||||
prompts: List[List[int]],
|
||||
prompts: list[list[int]],
|
||||
num_gpu_blocks: int,
|
||||
block_size: int,
|
||||
final_prompt_lens: List[int],
|
||||
continuations: Optional[List[List[int]]] = None,
|
||||
seq_ids: Optional[List[int]] = None,
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
final_prompt_lens: list[int],
|
||||
continuations: Optional[list[list[int]]] = None,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if continuations is None:
|
||||
continuations = [[] for _ in prompts]
|
||||
@@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts(
|
||||
|
||||
|
||||
def create_chunked_seq_group_metadata_from_prompt(
|
||||
prompt: List[int],
|
||||
prompt: list[int],
|
||||
num_gpu_blocks: int,
|
||||
chunk_size: int,
|
||||
block_size: int,
|
||||
seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
|
||||
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if seq_id is None:
|
||||
seq_id = 0
|
||||
@@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt(
|
||||
|
||||
|
||||
def assert_logprobs_dict_allclose(
|
||||
actual_logprobs: List[Dict[int, Logprob]],
|
||||
expected_logprobs: List[Dict[int, Logprob]]) -> None:
|
||||
actual_logprobs: list[dict[int, Logprob]],
|
||||
expected_logprobs: list[dict[int, Logprob]]) -> None:
|
||||
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
|
||||
actual_logprobs, expected_logprobs):
|
||||
assert set(single_step_actual_logprobs.keys()) == set(
|
||||
@@ -202,7 +201,7 @@ def create_sampler_output_list(
|
||||
token_ids: torch.Tensor,
|
||||
probs: GenericSequence[Optional[torch.Tensor]],
|
||||
logprobs: GenericSequence[Optional[torch.Tensor]],
|
||||
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
|
||||
seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
|
||||
num_steps, batch_size = token_ids.shape
|
||||
token_ids_by_step = token_ids.tolist()
|
||||
|
||||
@@ -231,9 +230,9 @@ def create_sampler_output_list(
|
||||
|
||||
def create_batch(batch_size,
|
||||
k,
|
||||
prompt_len: Union[int, List[int]] = 10,
|
||||
prompt_len: Union[int, list[int]] = 10,
|
||||
prev_output_token_len: int = 10,
|
||||
seq_ids: Optional[List[int]] = None,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
num_gpu_blocks: Optional[int] = None,
|
||||
block_size: Optional[int] = None,
|
||||
prefill_chunk_size: Optional[int] = None):
|
||||
|
||||
Reference in New Issue
Block a user