Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -5,9 +5,7 @@ from typing import (
Any,
ClassVar,
Literal,
Optional,
Protocol,
Union,
overload,
runtime_checkable,
)
@@ -63,12 +61,12 @@ class VllmModel(Protocol[T_co]):
) -> T_co: ...
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
def _check_vllm_model_init(model: type[object] | object) -> bool:
model_init = model.__init__
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -> bool:
def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool:
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
if not callable(model_get_input_embeddings):
logger.warning(
@@ -80,7 +78,7 @@ def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -
return True
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
def _check_vllm_model_forward(model: type[object] | object) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
@@ -108,8 +106,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...
def is_vllm_model(
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
model: type[object] | object,
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
return (
_check_vllm_model_init(model)
and _check_vllm_model_get_input_embeddings(model)
@@ -124,7 +122,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
def compute_logits(
self,
hidden_states: T,
) -> Optional[T]:
) -> T | None:
"""Return `None` if TP rank > 0."""
...
@@ -140,10 +138,8 @@ def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration
def is_text_generation_model(
model: Union[type[object], object],
) -> Union[
TypeIs[type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration]
]:
model: type[object] | object,
) -> TypeIs[type[VllmModelForTextGeneration]] | TypeIs[VllmModelForTextGeneration]:
if not is_vllm_model(model):
return False
@@ -190,8 +186,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...
def is_pooling_model(
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
model: type[object] | object,
) -> TypeIs[type[VllmModelForPooling]] | TypeIs[VllmModelForPooling]:
if not is_vllm_model(model):
return False
@@ -211,5 +207,5 @@ def default_pooling_type(pooling_type: str):
return func
def get_default_pooling_type(model: Union[type[object], object]) -> str:
def get_default_pooling_type(model: type[object] | object) -> str:
return getattr(model, "default_pooling_type", "LAST")