[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-18 17:52:00 +08:00
committed by GitHub
parent 5c79b0d648
commit 27e8d1ea3e
77 changed files with 431 additions and 383 deletions

View File

@@ -310,7 +310,7 @@ class Processor:
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
orig_sorted_mm_inputs = [
-decoder_mm_inputs.get_item(modality, idx)
+decoder_mm_inputs[modality][idx]
for modality, idx in sorted_mm_idxs
]
sorted_mm_positions = [

View File

@@ -18,12 +18,15 @@ from msgspec import msgpack
from vllm import envs
from vllm.logger import init_logger
# yapf: disable
from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalBatchedField,
MultiModalFieldConfig, MultiModalFieldElem,
MultiModalFlatField, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField, NestedTensors)
# yapf: enable
from vllm.v1.engine import UtilityResult
logger = init_logger(__name__)
@@ -116,12 +119,11 @@ class MsgpackEncoder:
if isinstance(obj, MultiModalKwargsItem):
return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargsItems):
return self._encode_mm_items(obj)
if isinstance(obj, MultiModalKwargs):
-return [
-    self._encode_mm_item(item)
-    for itemlist in obj._items_by_modality.values()
-    for item in itemlist
-]
+return self._encode_mm_kwargs(obj)
if isinstance(obj, UtilityResult):
result = obj.result
@@ -183,6 +185,12 @@ class MsgpackEncoder:
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
return {
modality: [self._encode_mm_item(item) for item in itemlist]
for modality, itemlist in items.items()
}
def _encode_mm_item(self,
item: MultiModalKwargsItem) -> list[dict[str, Any]]:
return [self._encode_mm_field_elem(elem) for elem in item.values()]
@@ -200,6 +208,12 @@ class MsgpackEncoder:
self._encode_mm_field(elem.field),
}
def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
return {
modality: self._encode_nested_tensors(data)
for modality, data in kw.items()
}
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor):
return self._encode_tensor(nt)
@@ -260,8 +274,10 @@ class MsgpackDecoder:
return slice(*obj)
if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargsItems):
return self._decode_mm_items(obj)
if issubclass(t, MultiModalKwargs):
-return MultiModalKwargs(self._decode_mm_items(obj))
+return self._decode_mm_kwargs(obj)
if t is UtilityResult:
return self._decode_utility_result(obj)
return obj
@@ -315,8 +331,11 @@ class MsgpackDecoder:
# Convert back to proper shape & type
return arr.view(torch_dtype).view(shape)
-def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
-    return [self._decode_mm_item(v) for v in obj]
+def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
+    return MultiModalKwargsItems({
+        modality: [self._decode_mm_item(item) for item in itemlist]
+        for modality, itemlist in obj.items()
+    })
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems(
@@ -339,6 +358,12 @@ class MsgpackDecoder:
obj["field"] = factory_meth(None, *field_args).field
return MultiModalFieldElem(**obj)
def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
return MultiModalKwargs({
modality: self._decode_nested_tensors(data)
for modality, data in obj.items()
})
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if isinstance(obj, (int, float)):
# Although it violates NestedTensors type, MultiModalKwargs

View File

@@ -10,8 +10,8 @@ import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalKwargsItem,
+                                    MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
@@ -57,8 +57,10 @@ class CachedRequestState:
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
-def mm_inputs(self) -> list[MultiModalKwargs]:
-    return [MultiModalKwargs([item]) for item in self.mm_kwargs]
+def mm_inputs(self) -> list[MultiModalKwargsItems]:
+    return [
+        MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
+    ]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:

View File

@@ -2218,11 +2218,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_data = dummy_decoder_data.multi_modal_data
# Result in the maximum GPU consumption of the model
-dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+dummy_mm_item = dummy_mm_data[modality][0]
+dummy_mm_items = [dummy_mm_item] * max_items_per_batch
 return next(mm_kwargs_group
             for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
-                [dummy_mm_item] * max_items_per_batch,
+                dummy_mm_items,
device=self.device,
pin_memory=self.pin_memory,
))

View File

@@ -1824,11 +1824,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_data = dummy_decoder_data.multi_modal_data
# Result in the maximum GPU consumption of the model
-dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+dummy_mm_item = dummy_mm_data[modality][0]
+dummy_mm_items = [dummy_mm_item] * max_items_per_batch
 return next(grouped_mm_kwargs
             for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
-                [dummy_mm_item] * max_items_per_batch,
+                dummy_mm_items,
device=self.device,
pin_memory=self.pin_memory,
))