[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)

This commit is contained in:
Cyrus Leung
2024-10-07 14:10:35 +08:00
committed by GitHub
parent 18b296fdb2
commit 8c6de96ea1
10 changed files with 342 additions and 37 deletions

View File

@@ -1,4 +1,3 @@
import inspect
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
@@ -6,9 +5,9 @@ import torch
from typing_extensions import TypeIs
from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors
@@ -142,9 +141,7 @@ def supports_lora(
return result
def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
def _supports_lora(model: Union[Type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)
@@ -175,10 +172,7 @@ class SupportsPP(Protocol):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
"""
@@ -205,10 +199,7 @@ class _SupportsPPType(Protocol):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
...
@@ -257,24 +248,19 @@ def supports_pp(
return supports_attributes and supports_inspect
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    """Return ``True`` if *model* declares the pipeline-parallel attributes.

    Accepts either a class or an instance; the check is structural
    (``runtime_checkable`` Protocol ``isinstance``), so any object exposing
    the ``SupportsPP`` attributes qualifies.

    NOTE(review): the source span contained both the pre- and post-change
    signatures from a diff; this body is the resolved post-change version.
    """
    if isinstance(model, type):
        # Classes are checked against the class-level protocol variant.
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Return ``True`` if *model*'s ``forward`` accepts ``intermediate_tensors``.

    Complements the attribute check: inspects the ``forward`` callable's
    signature (via the ``supports_kw`` helper) so that models which take the
    pipeline-parallel keyword argument are recognized even without explicitly
    inheriting the protocol.

    NOTE(review): the source span contained both the old inline
    ``inspect.signature`` body and the new ``supports_kw`` body from a diff;
    this is the resolved post-change version.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        # No forward method at all -> cannot support PP.
        return False

    return supports_kw(model_forward, "intermediate_tensors")
@runtime_checkable