[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import inspect
|
||||
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
|
||||
Protocol, Type, Union, overload, runtime_checkable)
|
||||
|
||||
@@ -6,9 +5,9 @@ import torch
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@@ -142,9 +141,7 @@ def supports_lora(
|
||||
return result
|
||||
|
||||
|
||||
def _supports_lora(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
|
||||
def _supports_lora(model: Union[Type[object], object]) -> bool:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _SupportsLoRAType)
|
||||
|
||||
@@ -175,10 +172,7 @@ class SupportsPP(Protocol):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: "AttentionMetadata",
|
||||
*,
|
||||
intermediate_tensors: Optional["IntermediateTensors"],
|
||||
) -> Union[torch.Tensor, "IntermediateTensors"]:
|
||||
"""
|
||||
@@ -205,10 +199,7 @@ class _SupportsPPType(Protocol):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: "AttentionMetadata",
|
||||
*,
|
||||
intermediate_tensors: Optional["IntermediateTensors"],
|
||||
) -> Union[torch.Tensor, "IntermediateTensors"]:
|
||||
...
|
||||
@@ -257,24 +248,19 @@ def supports_pp(
|
||||
return supports_attributes and supports_inspect
|
||||
|
||||
|
||||
def _supports_pp_attributes(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
|
||||
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _SupportsPPType)
|
||||
|
||||
return isinstance(model, SupportsPP)
|
||||
|
||||
|
||||
def _supports_pp_inspect(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
|
||||
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
|
||||
model_forward = getattr(model, "forward", None)
|
||||
if not callable(model_forward):
|
||||
return False
|
||||
|
||||
forward_params = inspect.signature(model_forward).parameters
|
||||
return "intermediate_tensors" in forward_params
|
||||
return supports_kw(model_forward, "intermediate_tensors")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
||||
Reference in New Issue
Block a user