Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol,
|
||||
Union, overload, runtime_checkable)
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -38,8 +47,7 @@ class VllmModel(Protocol[T_co]):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@@ -52,8 +60,7 @@ class VllmModel(Protocol[T_co]):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> T_co:
|
||||
...
|
||||
) -> T_co: ...
|
||||
|
||||
|
||||
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
|
||||
@@ -61,8 +68,7 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
|
||||
return supports_kw(model_init, "vllm_config")
|
||||
|
||||
|
||||
def _check_vllm_model_get_input_embeddings(
|
||||
model: Union[type[object], object]) -> bool:
|
||||
def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -> bool:
|
||||
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if not callable(model_get_input_embeddings):
|
||||
logger.warning(
|
||||
@@ -80,11 +86,9 @@ def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
|
||||
return False
|
||||
|
||||
vllm_kws = ("input_ids", "positions")
|
||||
missing_kws = tuple(kw for kw in vllm_kws
|
||||
if not supports_kw(model_forward, kw))
|
||||
missing_kws = tuple(kw for kw in vllm_kws if not supports_kw(model_forward, kw))
|
||||
|
||||
if missing_kws and (isinstance(model, type)
|
||||
and issubclass(model, nn.Module)):
|
||||
if missing_kws and (isinstance(model, type) and issubclass(model, nn.Module)):
|
||||
logger.warning(
|
||||
"The model (%s) is missing "
|
||||
"vLLM-specific keywords from its `forward` method: %s",
|
||||
@@ -96,21 +100,21 @@ def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
|
||||
|
||||
|
||||
@overload
|
||||
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
|
||||
...
|
||||
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
|
||||
...
|
||||
def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...
|
||||
|
||||
|
||||
def is_vllm_model(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
|
||||
return (_check_vllm_model_init(model)
|
||||
and _check_vllm_model_get_input_embeddings(model)
|
||||
and _check_vllm_model_forward(model))
|
||||
return (
|
||||
_check_vllm_model_init(model)
|
||||
and _check_vllm_model_get_input_embeddings(model)
|
||||
and _check_vllm_model_forward(model)
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@@ -127,20 +131,19 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
|
||||
|
||||
@overload
|
||||
def is_text_generation_model(
|
||||
model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
|
||||
...
|
||||
model: type[object],
|
||||
) -> TypeIs[type[VllmModelForTextGeneration]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def is_text_generation_model(
|
||||
model: object) -> TypeIs[VllmModelForTextGeneration]:
|
||||
...
|
||||
def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration]: ...
|
||||
|
||||
|
||||
def is_text_generation_model(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
|
||||
TypeIs[VllmModelForTextGeneration]]:
|
||||
) -> Union[
|
||||
TypeIs[type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration]
|
||||
]:
|
||||
if not is_vllm_model(model):
|
||||
return False
|
||||
|
||||
@@ -179,13 +182,11 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
||||
|
||||
|
||||
@overload
|
||||
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
|
||||
...
|
||||
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
|
||||
...
|
||||
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...
|
||||
|
||||
|
||||
def is_pooling_model(
|
||||
|
||||
Reference in New Issue
Block a user