Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,9 +5,7 @@ from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
@@ -63,12 +61,12 @@ class VllmModel(Protocol[T_co]):
|
||||
) -> T_co: ...
|
||||
|
||||
|
||||
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
|
||||
def _check_vllm_model_init(model: type[object] | object) -> bool:
|
||||
model_init = model.__init__
|
||||
return supports_kw(model_init, "vllm_config")
|
||||
|
||||
|
||||
def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -> bool:
|
||||
def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool:
|
||||
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if not callable(model_get_input_embeddings):
|
||||
logger.warning(
|
||||
@@ -80,7 +78,7 @@ def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -
|
||||
return True
|
||||
|
||||
|
||||
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
|
||||
def _check_vllm_model_forward(model: type[object] | object) -> bool:
|
||||
model_forward = getattr(model, "forward", None)
|
||||
if not callable(model_forward):
|
||||
return False
|
||||
@@ -108,8 +106,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...
|
||||
|
||||
|
||||
def is_vllm_model(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
|
||||
return (
|
||||
_check_vllm_model_init(model)
|
||||
and _check_vllm_model_get_input_embeddings(model)
|
||||
@@ -124,7 +122,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: T,
|
||||
) -> Optional[T]:
|
||||
) -> T | None:
|
||||
"""Return `None` if TP rank > 0."""
|
||||
...
|
||||
|
||||
@@ -140,10 +138,8 @@ def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration
|
||||
|
||||
|
||||
def is_text_generation_model(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[
|
||||
TypeIs[type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration]
|
||||
]:
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[VllmModelForTextGeneration]] | TypeIs[VllmModelForTextGeneration]:
|
||||
if not is_vllm_model(model):
|
||||
return False
|
||||
|
||||
@@ -190,8 +186,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...
|
||||
|
||||
|
||||
def is_pooling_model(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[VllmModelForPooling]] | TypeIs[VllmModelForPooling]:
|
||||
if not is_vllm_model(model):
|
||||
return False
|
||||
|
||||
@@ -211,5 +207,5 @@ def default_pooling_type(pooling_type: str):
|
||||
return func
|
||||
|
||||
|
||||
def get_default_pooling_type(model: Union[type[object], object]) -> str:
|
||||
def get_default_pooling_type(model: type[object] | object) -> str:
|
||||
return getattr(model, "default_pooling_type", "LAST")
|
||||
|
||||
Reference in New Issue
Block a user