Signed-off-by: Roger Wang <hey@rogerw.me> Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.sequence import SequenceGroupMetadata
|
from vllm.sequence import SequenceGroupMetadata
|
||||||
|
|
||||||
from .inputs import MultiModalKwargs, PlaceholderRange
|
from .inputs import MultiModalKwargs, NestedTensors, PlaceholderRange
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
@@ -56,7 +56,8 @@ class MultiModalPlaceholderMap:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_seq_group(
|
def from_seq_group(
|
||||||
cls, seq_group: "SequenceGroupMetadata", positions: range
|
cls, seq_group: "SequenceGroupMetadata", positions: range
|
||||||
) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
|
) -> tuple[dict[str, NestedTensors], dict[str,
|
||||||
|
"MultiModalPlaceholderMap"]]:
|
||||||
"""
|
"""
|
||||||
Returns the multi-modal items that intersect with the portion of a
|
Returns the multi-modal items that intersect with the portion of a
|
||||||
prompt (``seq_group``) represented by ``positions``, as well as a
|
prompt (``seq_group``) represented by ``positions``, as well as a
|
||||||
@@ -99,7 +100,7 @@ class MultiModalPlaceholderMap:
|
|||||||
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
||||||
|
|
||||||
if not seq_mm_data or not seq_mm_placeholders:
|
if not seq_mm_data or not seq_mm_placeholders:
|
||||||
return MultiModalKwargs(), {}
|
return MultiModalKwargs().get_data(), {}
|
||||||
|
|
||||||
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
||||||
|
|
||||||
@@ -116,6 +117,8 @@ class MultiModalPlaceholderMap:
|
|||||||
|
|
||||||
placeholder_maps[modality] = placeholder_map
|
placeholder_maps[modality] = placeholder_map
|
||||||
|
|
||||||
|
seq_mm_data = seq_mm_data if isinstance(
|
||||||
|
seq_mm_data, dict) else seq_mm_data.get_data()
|
||||||
return seq_mm_data, placeholder_maps
|
return seq_mm_data, placeholder_maps
|
||||||
|
|
||||||
def append_items_from_seq_group(
|
def append_items_from_seq_group(
|
||||||
|
|||||||
@@ -664,7 +664,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
|||||||
def modality(self) -> str:
|
def modality(self) -> str:
|
||||||
return self._modality
|
return self._modality
|
||||||
|
|
||||||
def get_data(self) -> Mapping[str, NestedTensors]:
|
def get_data(self) -> dict[str, NestedTensors]:
|
||||||
return {key: elem.data for key, elem in self.items()}
|
return {key: elem.data for key, elem in self.items()}
|
||||||
|
|
||||||
|
|
||||||
@@ -720,7 +720,7 @@ class MultiModalKwargs:
|
|||||||
items_by_modality = full_groupby(items, key=lambda x: x.modality)
|
items_by_modality = full_groupby(items, key=lambda x: x.modality)
|
||||||
self._items_by_modality = dict(items_by_modality)
|
self._items_by_modality = dict(items_by_modality)
|
||||||
|
|
||||||
self._data: Optional[Mapping[str, NestedTensors]] = None
|
self._data: Optional[dict[str, NestedTensors]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def modalities(self):
|
def modalities(self):
|
||||||
@@ -883,7 +883,7 @@ class MultiModalKwargs:
|
|||||||
|
|
||||||
def get_data(self,
|
def get_data(self,
|
||||||
*,
|
*,
|
||||||
pin_memory: bool = False) -> Mapping[str, NestedTensors]:
|
pin_memory: bool = False) -> dict[str, NestedTensors]:
|
||||||
if self._data is not None:
|
if self._data is not None:
|
||||||
return self._data
|
return self._data
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from vllm.pooling_params import PoolingParams
|
|||||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.multimodal.inputs import NestedTensors
|
||||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||||
KVConnectorOutput)
|
KVConnectorOutput)
|
||||||
|
|
||||||
@@ -978,7 +979,8 @@ class SequenceGroupMetadata(
|
|||||||
state: Optional[SequenceGroupState] = msgspec.field(
|
state: Optional[SequenceGroupState] = msgspec.field(
|
||||||
default_factory=lambda: SequenceGroupState())
|
default_factory=lambda: SequenceGroupState())
|
||||||
token_type_ids: Optional[list[int]] = None
|
token_type_ids: Optional[list[int]] = None
|
||||||
multi_modal_data: Optional[MultiModalKwargs] = None
|
multi_modal_data: Optional[Union[MultiModalKwargs,
|
||||||
|
dict[str, "NestedTensors"]]] = None
|
||||||
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
|
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
|
||||||
encoder_seq_data: Optional[SequenceData] = None
|
encoder_seq_data: Optional[SequenceData] = None
|
||||||
cross_block_table: Optional[list[int]] = None
|
cross_block_table: Optional[list[int]] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user