[Models] Add remaining model PP support (#7168)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai> Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
committed by
GitHub
parent
303d44790a
commit
0f6d7a9a34
@@ -1,11 +1,17 @@
|
||||
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
|
||||
Union, overload, runtime_checkable)
|
||||
import inspect
|
||||
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
|
||||
Protocol, Type, Union, overload, runtime_checkable)
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -22,7 +28,7 @@ class SupportsMultiModal(Protocol):
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
|
||||
def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
|
||||
...
|
||||
|
||||
|
||||
@@ -32,7 +38,7 @@ class SupportsMultiModal(Protocol):
|
||||
class _SupportsMultiModalType(Protocol):
|
||||
supports_multimodal: Literal[True]
|
||||
|
||||
def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
|
||||
def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
|
||||
...
|
||||
|
||||
|
||||
@@ -75,7 +81,7 @@ class SupportsLoRA(Protocol):
|
||||
embedding_padding_modules: ClassVar[List[str]]
|
||||
|
||||
# lora_config is None when LoRA is not enabled
|
||||
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
|
||||
def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
|
||||
...
|
||||
|
||||
|
||||
@@ -90,7 +96,7 @@ class _SupportsLoRAType(Protocol):
|
||||
embedding_modules: Dict[str, str]
|
||||
embedding_padding_modules: List[str]
|
||||
|
||||
def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
|
||||
def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
|
||||
...
|
||||
|
||||
|
||||
@@ -145,6 +151,132 @@ def _supports_lora(
|
||||
return isinstance(model, SupportsLoRA)
|
||||
|
||||
|
||||
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` only for the last PP rank.
        """
        ...
|
||||
|
||||
|
||||
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    """Structural twin of :class:`SupportsPP` with ``supports_pp`` declared as
    a plain attribute (not ``ClassVar``), so that a model *class* object can be
    checked with ``isinstance`` against this protocol."""

    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        ...
|
||||
|
||||
|
||||
@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Return whether *model* (a class or an instance) supports pipeline
    parallel, warning about partially-conforming models along the way."""
    has_pp_attrs = _supports_pp_attributes(model)
    accepts_intermediate = _supports_pp_inspect(model)

    # Attributes claim PP support but `forward` cannot receive the
    # intermediate tensors handed over from the previous PP rank.
    if has_pp_attrs and not accepts_intermediate:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not has_pp_attrs:
        missing_attrs = tuple(
            attr for attr in ("make_empty_intermediate_tensors", )
            if not hasattr(model, attr))
        claims_pp = getattr(model, "supports_pp", False)

        if claims_pp and missing_attrs:
            # Flag set but required API missing.
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                missing_attrs,
            )
        elif not claims_pp and not missing_attrs:
            # API present but the flag was never set.
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_pp_attrs and accepts_intermediate
|
||||
|
||||
|
||||
def _supports_pp_attributes(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Structurally check the PP attributes of *model* (class or instance).

    A class object is matched against :class:`_SupportsPPType` because
    ``issubclass`` cannot be used with ``ClassVar``-bearing runtime-checkable
    protocols; an instance is matched against :class:`SupportsPP` directly.
    """
    protocol = _SupportsPPType if isinstance(model, type) else SupportsPP
    return isinstance(model, protocol)
|
||||
|
||||
|
||||
def _supports_pp_inspect(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
|
||||
model_forward = getattr(model, "forward", None)
|
||||
if not callable(model_forward):
|
||||
return False
|
||||
|
||||
forward_params = inspect.signature(model_forward).parameters
|
||||
return "intermediate_tensors" in forward_params
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasInnerState(Protocol):
|
||||
"""The interface required for all models that has inner state."""
|
||||
@@ -158,7 +290,7 @@ class HasInnerState(Protocol):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
scheduler_config: Optional[SchedulerConfig] = None) -> None:
|
||||
scheduler_config: Optional["SchedulerConfig"] = None) -> None:
|
||||
...
|
||||
|
||||
|
||||
@@ -168,7 +300,7 @@ class _HasInnerStateType(Protocol):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
scheduler_config: Optional[SchedulerConfig] = None) -> None:
|
||||
scheduler_config: Optional["SchedulerConfig"] = None) -> None:
|
||||
...
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user