Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload,
|
||||
from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
|
||||
runtime_checkable)
|
||||
|
||||
import torch
|
||||
@@ -20,7 +20,7 @@ logger = init_logger(__name__)
|
||||
|
||||
# The type of hidden states
|
||||
# Currently, T = torch.Tensor for all models except for Medusa
|
||||
# which has T = List[torch.Tensor]
|
||||
# which has T = list[torch.Tensor]
|
||||
T = TypeVar("T", default=torch.Tensor)
|
||||
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
|
||||
|
||||
@@ -48,12 +48,12 @@ class VllmModel(Protocol[T_co]):
|
||||
...
|
||||
|
||||
|
||||
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
|
||||
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
|
||||
model_init = model.__init__
|
||||
return supports_kw(model_init, "vllm_config")
|
||||
|
||||
|
||||
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
|
||||
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
|
||||
model_forward = getattr(model, "forward", None)
|
||||
if not callable(model_forward):
|
||||
return False
|
||||
@@ -75,7 +75,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
|
||||
|
||||
|
||||
@overload
|
||||
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
|
||||
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
|
||||
...
|
||||
|
||||
|
||||
@@ -85,8 +85,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
|
||||
|
||||
|
||||
def is_vllm_model(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
|
||||
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
|
||||
|
||||
@overload
|
||||
def is_text_generation_model(
|
||||
model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
|
||||
model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
|
||||
...
|
||||
|
||||
|
||||
@@ -116,8 +116,8 @@ def is_text_generation_model(
|
||||
|
||||
|
||||
def is_text_generation_model(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
|
||||
TypeIs[VllmModelForTextGeneration]]:
|
||||
if not is_vllm_model(model):
|
||||
return False
|
||||
@@ -142,7 +142,7 @@ class VllmModelForPooling(VllmModel[T], Protocol[T]):
|
||||
|
||||
|
||||
@overload
|
||||
def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]:
|
||||
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
|
||||
...
|
||||
|
||||
|
||||
@@ -152,8 +152,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
|
||||
|
||||
|
||||
def is_pooling_model(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
|
||||
if not is_vllm_model(model):
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user