[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -310,7 +310,7 @@ class Processor:
         sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

         orig_sorted_mm_inputs = [
-            decoder_mm_inputs.get_item(modality, idx)
+            decoder_mm_inputs[modality][idx]
             for modality, idx in sorted_mm_idxs
         ]
         sorted_mm_positions = [
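The access-pattern change above follows from the new container layout: MultiModalKwargsItems maps each modality to a list of per-item kwargs, so an item is reached with items[modality][idx] instead of the old MultiModalKwargs.get_item(modality, idx) call. A minimal standalone sketch of that layout (toy class and data, not vLLM's implementation):

# Toy illustration only; the class and field names here are assumptions.
from collections import UserDict

class ToyKwargsItems(UserDict):
    """Maps modality -> list of per-item kwargs, so items[modality][idx] works."""

items = ToyKwargsItems({
    "image": [{"pixel_values": "img0"}, {"pixel_values": "img1"}],
    "audio": [{"input_features": "aud0"}],
})
sorted_mm_idxs = [("audio", 0), ("image", 1), ("image", 0)]
# Same shape of loop as orig_sorted_mm_inputs above.
sorted_inputs = [items[modality][idx] for modality, idx in sorted_mm_idxs]
print(sorted_inputs)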
@@ -18,12 +18,15 @@ from msgspec import msgpack

 from vllm import envs
 from vllm.logger import init_logger
+# yapf: disable
 from vllm.multimodal.inputs import (BaseMultiModalField,
                                     MultiModalBatchedField,
                                     MultiModalFieldConfig, MultiModalFieldElem,
                                     MultiModalFlatField, MultiModalKwargs,
                                     MultiModalKwargsItem,
+                                    MultiModalKwargsItems,
                                     MultiModalSharedField, NestedTensors)
+# yapf: enable
 from vllm.v1.engine import UtilityResult

 logger = init_logger(__name__)
@@ -116,12 +119,11 @@ class MsgpackEncoder:
         if isinstance(obj, MultiModalKwargsItem):
             return self._encode_mm_item(obj)

+        if isinstance(obj, MultiModalKwargsItems):
+            return self._encode_mm_items(obj)
+
         if isinstance(obj, MultiModalKwargs):
-            return [
-                self._encode_mm_item(item)
-                for itemlist in obj._items_by_modality.values()
-                for item in itemlist
-            ]
+            return self._encode_mm_kwargs(obj)

         if isinstance(obj, UtilityResult):
             result = obj.result
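The isinstance chain above runs inside msgspec's enc_hook, which is invoked for any object the msgpack encoder cannot serialize natively; each multimodal container is converted into plain dicts and lists first. A self-contained sketch of the same dispatch pattern with stand-in classes (not vLLM's MsgpackEncoder):

# Stand-in types; only the enc_hook dispatch pattern mirrors the code above.
import msgspec

class Item:                      # one multimodal item's kwargs
    def __init__(self, **kw):
        self.kw = kw

class Items:                     # modality -> list[Item]
    def __init__(self, data):
        self.data = data

def enc_hook(obj):
    if isinstance(obj, Item):
        return obj.kw                                   # item -> plain dict
    if isinstance(obj, Items):
        return {m: [it.kw for it in lst]                # items -> dict of lists
                for m, lst in obj.data.items()}
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
payload = encoder.encode(Items({"image": [Item(pixel_values=[1, 2, 3])]}))
print(len(payload), "bytes")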
@@ -183,6 +185,12 @@ class MsgpackEncoder:
         dtype = str(obj.dtype).removeprefix("torch.")
         return dtype, obj.shape, data

+    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
+        return {
+            modality: [self._encode_mm_item(item) for item in itemlist]
+            for modality, itemlist in items.items()
+        }
+
     def _encode_mm_item(self,
                         item: MultiModalKwargsItem) -> list[dict[str, Any]]:
         return [self._encode_mm_field_elem(elem) for elem in item.values()]
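_encode_mm_items therefore produces a plain nested structure: a dict keyed by modality, whose values are lists of encoded items, each item being the list of its encoded field elements. A rough picture of the resulting payload shape (contents elided; purely illustrative, not an exact wire format):

# Illustrative shape only; the per-element dict keys are not spelled out here.
encoded_items = {
    "image": [               # one entry per image item in the request
        [                    # one entry per field element of that item
            {"...": "..."},  # e.g. an encoded "pixel_values" element
        ],
    ],
    # modalities with no items may simply be absent from the dict
}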
@@ -200,6 +208,12 @@ class MsgpackEncoder:
             self._encode_mm_field(elem.field),
         }

+    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
+        return {
+            modality: self._encode_nested_tensors(data)
+            for modality, data in kw.items()
+        }
+
     def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
         if isinstance(nt, torch.Tensor):
             return self._encode_tensor(nt)
@@ -260,8 +274,10 @@ class MsgpackDecoder:
             return slice(*obj)
         if issubclass(t, MultiModalKwargsItem):
             return self._decode_mm_item(obj)
+        if issubclass(t, MultiModalKwargsItems):
+            return self._decode_mm_items(obj)
         if issubclass(t, MultiModalKwargs):
-            return MultiModalKwargs(self._decode_mm_items(obj))
+            return self._decode_mm_kwargs(obj)
         if t is UtilityResult:
             return self._decode_utility_result(obj)
         return obj
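On the decode side, the issubclass checks run inside msgspec's dec_hook, which receives the expected target type together with the already-decoded plain object and rebuilds the richer container. A minimal sketch of that mechanism (invented Point type, not vLLM's MsgpackDecoder):

# Toy dec_hook example; the Point type is invented for illustration.
from typing import Any
import msgspec

class Point:
    def __init__(self, x: int, y: int):
        self.x, self.y = x, y

def enc_hook(obj: Any) -> Any:
    if isinstance(obj, Point):
        return (obj.x, obj.y)          # serialize as a plain 2-tuple
    raise NotImplementedError

def dec_hook(t: type, obj: Any) -> Any:
    if issubclass(t, Point):
        return Point(*obj)             # rebuild the rich type from the decoded list
    return obj

enc = msgspec.msgpack.Encoder(enc_hook=enc_hook)
dec = msgspec.msgpack.Decoder(Point, dec_hook=dec_hook)
assert dec.decode(enc.encode(Point(1, 2))).y == 2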
@@ -315,8 +331,11 @@ class MsgpackDecoder:
         # Convert back to proper shape & type
         return arr.view(torch_dtype).view(shape)

-    def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
-        return [self._decode_mm_item(v) for v in obj]
+    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
+        return MultiModalKwargsItems({
+            modality: [self._decode_mm_item(item) for item in itemlist]
+            for modality, itemlist in obj.items()
+        })

     def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
         return MultiModalKwargsItem.from_elems(
@@ -339,6 +358,12 @@ class MsgpackDecoder:
         obj["field"] = factory_meth(None, *field_args).field
         return MultiModalFieldElem(**obj)

+    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
+        return MultiModalKwargs({
+            modality: self._decode_nested_tensors(data)
+            for modality, data in obj.items()
+        })
+
     def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
         if isinstance(obj, (int, float)):
             # Although it violates NestedTensors type, MultiModalKwargs
||||
@@ -10,8 +10,8 @@ import torch
 from typing_extensions import deprecated

 from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalKwargsItem,
+                                    MultiModalKwargsItems, PlaceholderRange)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
@@ -57,8 +57,10 @@ class CachedRequestState:
     @property
     @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                 "removed in v0.13. Please use `mm_kwargs` instead.")
-    def mm_inputs(self) -> list[MultiModalKwargs]:
-        return [MultiModalKwargs([item]) for item in self.mm_kwargs]
+    def mm_inputs(self) -> list[MultiModalKwargsItems]:
+        return [
+            MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
+        ]

     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
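The compatibility shim above keeps the old mm_inputs property alive while steering callers to mm_kwargs, using typing_extensions.deprecated (PEP 702) so both static checkers and runtime flag the old accessor. A small sketch of the same pattern on an invented class:

# Invented example class; only the @property/@deprecated layering mirrors the code above.
from typing_extensions import deprecated

class Request:
    def __init__(self, kwargs_items):
        self.mm_kwargs = kwargs_items

    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs`.")
    def mm_inputs(self):
        # Wrap each item so old call sites keep receiving one container per item.
        return [[item] for item in self.mm_kwargs]

req = Request(["itemA", "itemB"])
print(req.mm_inputs)  # accessing the property emits a DeprecationWarning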
@@ -2218,11 +2218,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         dummy_mm_data = dummy_decoder_data.multi_modal_data

         # Result in the maximum GPU consumption of the model
-        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+        dummy_mm_item = dummy_mm_data[modality][0]
+        dummy_mm_items = [dummy_mm_item] * max_items_per_batch

         return next(mm_kwargs_group
                     for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
-                        [dummy_mm_item] * max_items_per_batch,
+                        dummy_mm_items,
                         device=self.device,
                         pin_memory=self.pin_memory,
                     ))
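Binding the replicated dummy item to dummy_mm_items makes the call read like the real batching path, where group_mm_kwargs_by_modality takes a flat sequence of items and yields one batched group per modality. A toy version of that grouping step (invented helper, not vLLM's implementation):

# Invented stand-in for group_mm_kwargs_by_modality; items are (modality, payload) pairs.
# Assumes items of the same modality are consecutive, as with a replicated dummy item.
from itertools import groupby

def group_by_modality(items):
    for modality, group in groupby(items, key=lambda it: it[0]):
        batch = [payload for _, payload in group]
        yield modality, len(batch), batch            # (modality, num_items, grouped kwargs)

dummy_items = [("image", {"pixel_values": i}) for i in range(3)]
mm_kwargs_group = next(group for _, _, group in group_by_modality(dummy_items))
print(mm_kwargs_group)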
@@ -1824,11 +1824,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         dummy_mm_data = dummy_decoder_data.multi_modal_data

         # Result in the maximum GPU consumption of the model
-        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+        dummy_mm_item = dummy_mm_data[modality][0]
+        dummy_mm_items = [dummy_mm_item] * max_items_per_batch

         return next(grouped_mm_kwargs
                     for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
-                        [dummy_mm_item] * max_items_per_batch,
+                        dummy_mm_items,
                         device=self.device,
                         pin_memory=self.pin_memory,
                     ))