[Refactor] Defer tensor data construction in MultiModalKwargs (#23030)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-17 12:05:50 +08:00
committed by GitHub
parent 94096a47c9
commit 5c32143b9d
12 changed files with 73 additions and 104 deletions

View File

@@ -117,16 +117,9 @@ class MsgpackEncoder:
return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj
if not mm.modalities:
# just return the main dict if there are no modalities.
return dict(mm)
# ignore the main dict, it will be re-indexed.
# Any tensors *not* indexed by modality will be ignored.
return [
self._encode_mm_item(item)
for itemlist in mm._items_by_modality.values()
for itemlist in obj._items_by_modality.values()
for item in itemlist
]
@@ -268,13 +261,7 @@ class MsgpackDecoder:
if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
self._decode_mm_items(obj))
return MultiModalKwargs({
k: self._decode_nested_tensors(v)
for k, v in obj.items()
})
return MultiModalKwargs(self._decode_mm_items(obj))
if t is UtilityResult:
return self._decode_utility_result(obj)
return obj

View File

@@ -58,7 +58,7 @@ class CachedRequestState:
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargs]:
return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs]
return [MultiModalKwargs([item]) for item in self.mm_kwargs]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens: