Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast
|
||||
from typing import Any, Callable, TypedDict, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -68,7 +68,7 @@ _SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]
|
||||
_PromptMultiModalInput = list[_M] | list[list[_M]]
|
||||
|
||||
PromptImageInput = _PromptMultiModalInput[Image.Image]
|
||||
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
|
||||
@@ -267,7 +267,7 @@ class HfRunner:
|
||||
|
||||
return "cpu" if current_platform.is_cpu() else current_platform.device_type
|
||||
|
||||
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
|
||||
def wrap_device(self, x: _T, device: str | None = None) -> _T:
|
||||
if x is None or isinstance(x, (bool,)):
|
||||
return x
|
||||
|
||||
@@ -287,14 +287,14 @@ class HfRunner:
|
||||
model_name: str,
|
||||
dtype: str = "auto",
|
||||
*,
|
||||
model_kwargs: Optional[dict[str, Any]] = None,
|
||||
model_kwargs: dict[str, Any] | None = None,
|
||||
trust_remote_code: bool = True,
|
||||
is_sentence_transformer: bool = False,
|
||||
is_cross_encoder: bool = False,
|
||||
skip_tokenizer_init: bool = False,
|
||||
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
||||
# Set this to avoid hanging issue
|
||||
default_torch_num_threads: Optional[int] = None,
|
||||
default_torch_num_threads: int | None = None,
|
||||
) -> None:
|
||||
init_ctx = (
|
||||
nullcontext()
|
||||
@@ -319,7 +319,7 @@ class HfRunner:
|
||||
model_name: str,
|
||||
dtype: str = "auto",
|
||||
*,
|
||||
model_kwargs: Optional[dict[str, Any]] = None,
|
||||
model_kwargs: dict[str, Any] | None = None,
|
||||
trust_remote_code: bool = True,
|
||||
is_sentence_transformer: bool = False,
|
||||
is_cross_encoder: bool = False,
|
||||
@@ -406,11 +406,11 @@ class HfRunner:
|
||||
|
||||
def get_inputs(
|
||||
self,
|
||||
prompts: Union[list[str], list[list[int]]],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]:
|
||||
prompts: list[str] | list[list[int]],
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]:
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
|
||||
@@ -420,9 +420,7 @@ class HfRunner:
|
||||
if audios is not None:
|
||||
assert len(prompts) == len(audios)
|
||||
|
||||
all_inputs: list[
|
||||
Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]
|
||||
] = []
|
||||
all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
if isinstance(prompt, str):
|
||||
processor_kwargs: dict[str, Any] = {
|
||||
@@ -494,10 +492,10 @@ class HfRunner:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: Union[list[str], list[list[int]]],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
prompts: list[str] | list[list[int]],
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
all_inputs = self.get_inputs(
|
||||
@@ -522,11 +520,11 @@ class HfRunner:
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: Union[list[str], list[list[int]]],
|
||||
prompts: list[str] | list[list[int]],
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[int], str]]:
|
||||
outputs = self.generate(
|
||||
@@ -546,9 +544,9 @@ class HfRunner:
|
||||
prompts: list[str],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
outputs = self.generate(
|
||||
prompts,
|
||||
@@ -574,9 +572,9 @@ class HfRunner:
|
||||
self,
|
||||
prompts: list[str],
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[list[torch.Tensor]]:
|
||||
all_inputs = self.get_inputs(
|
||||
@@ -624,7 +622,7 @@ class HfRunner:
|
||||
def _hidden_states_to_logprobs(
|
||||
self,
|
||||
hidden_states: tuple[tuple[torch.Tensor, ...], ...],
|
||||
num_logprobs: Optional[int],
|
||||
num_logprobs: int | None,
|
||||
) -> tuple[list[dict[int, float]], int]:
|
||||
seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
|
||||
output_len = len(hidden_states)
|
||||
@@ -652,10 +650,10 @@ class HfRunner:
|
||||
self,
|
||||
prompts: list[str],
|
||||
max_tokens: int,
|
||||
num_logprobs: Optional[int],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
num_logprobs: int | None,
|
||||
images: PromptImageInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[TokensTextLogprobs]:
|
||||
all_inputs = self.get_inputs(
|
||||
@@ -734,20 +732,20 @@ class VllmRunner:
|
||||
model_name: str,
|
||||
runner: RunnerOption = "auto",
|
||||
convert: ConvertOption = "auto",
|
||||
tokenizer_name: Optional[str] = None,
|
||||
tokenizer_name: str | None = None,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = True,
|
||||
seed: Optional[int] = 0,
|
||||
max_model_len: Optional[int] = 1024,
|
||||
seed: int | None = 0,
|
||||
max_model_len: int | None = 1024,
|
||||
dtype: str = "auto",
|
||||
disable_log_stats: bool = True,
|
||||
tensor_parallel_size: int = 1,
|
||||
block_size: int = 16 if not torch.xpu.is_available() else 64,
|
||||
enable_chunked_prefill: Optional[bool] = False,
|
||||
enable_chunked_prefill: bool | None = False,
|
||||
swap_space: int = 4,
|
||||
enforce_eager: Optional[bool] = False,
|
||||
enforce_eager: bool | None = False,
|
||||
# Set this to avoid hanging issue
|
||||
default_torch_num_threads: Optional[int] = None,
|
||||
default_torch_num_threads: int | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
init_ctx = (
|
||||
@@ -785,10 +783,10 @@ class VllmRunner:
|
||||
|
||||
def get_inputs(
|
||||
self,
|
||||
prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if any(
|
||||
x is not None and len(x) != len(prompts) for x in [images, videos, audios]
|
||||
@@ -824,11 +822,11 @@ class VllmRunner:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
|
||||
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||
@@ -871,11 +869,11 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: list[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
images: PromptImageInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
|
||||
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
|
||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||
|
||||
req_outputs = self.llm.generate(
|
||||
@@ -894,11 +892,11 @@ class VllmRunner:
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
|
||||
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
@@ -916,15 +914,15 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: list[str],
|
||||
max_tokens: int,
|
||||
num_logprobs: Optional[int],
|
||||
num_prompt_logprobs: Optional[int] = None,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
num_logprobs: int | None,
|
||||
num_prompt_logprobs: int | None = None,
|
||||
images: PromptImageInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
|
||||
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
|
||||
greedy_logprobs_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
@@ -957,7 +955,7 @@ class VllmRunner:
|
||||
perplexities = []
|
||||
for output in outputs:
|
||||
output = cast(TokensTextLogprobsPromptLogprobs, output)
|
||||
token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])
|
||||
token_datas = cast(list[dict[int, Logprob] | None], output[3])
|
||||
assert token_datas[0] is None
|
||||
token_log_probs = []
|
||||
for token_data in token_datas[1:]:
|
||||
@@ -976,10 +974,10 @@ class VllmRunner:
|
||||
prompts: list[str],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
concurrency_limit: Optional[int] = None,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
concurrency_limit: int | None = None,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||
|
||||
@@ -1002,9 +1000,9 @@ class VllmRunner:
|
||||
def embed(
|
||||
self,
|
||||
prompts: list[str],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
images: PromptImageInput | None = None,
|
||||
videos: PromptVideoInput | None = None,
|
||||
audios: PromptAudioInput | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> list[list[float]]:
|
||||
@@ -1023,8 +1021,8 @@ class VllmRunner:
|
||||
|
||||
def score(
|
||||
self,
|
||||
text_1: Union[str, list[str]],
|
||||
text_2: Union[str, list[str]],
|
||||
text_1: list[str] | str,
|
||||
text_2: list[str] | str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> list[float]:
|
||||
@@ -1226,8 +1224,8 @@ def _find_free_port() -> int:
|
||||
class LocalAssetServer:
|
||||
address: str
|
||||
port: int
|
||||
server: Optional[http.server.ThreadingHTTPServer]
|
||||
thread: Optional[threading.Thread]
|
||||
server: http.server.ThreadingHTTPServer | None
|
||||
thread: threading.Thread | None
|
||||
|
||||
def __init__(self, address: str = "127.0.0.1") -> None:
|
||||
self.address = address
|
||||
|
||||
Reference in New Issue
Block a user