Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload,
from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
runtime_checkable)
import torch
@@ -20,7 +20,7 @@ logger = init_logger(__name__)
# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor]
# which has T = list[torch.Tensor]
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
@@ -48,12 +48,12 @@ class VllmModel(Protocol[T_co]):
...
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
model_init = model.__init__
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
@@ -75,7 +75,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
@overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
...
@@ -85,8 +85,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
def is_vllm_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
@@ -105,7 +105,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
@overload
def is_text_generation_model(
model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
...
@@ -116,8 +116,8 @@ def is_text_generation_model(
def is_text_generation_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
TypeIs[VllmModelForTextGeneration]]:
if not is_vllm_model(model):
return False
@@ -142,7 +142,7 @@ class VllmModelForPooling(VllmModel[T], Protocol[T]):
@overload
def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]:
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
...
@@ -152,8 +152,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
def is_pooling_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
if not is_vllm_model(model):
return False