[Model] Move multimodal_cpu_fields definition to field config (#30181)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int):
|
|||||||
modality=modality,
|
modality=modality,
|
||||||
key=key,
|
key=key,
|
||||||
data=torch.empty((size,), dtype=torch.int8),
|
data=torch.empty((size,), dtype=torch.int8),
|
||||||
field=MultiModalSharedField(1),
|
field=MultiModalSharedField(batch_size=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def _dummy_elem(
|
|||||||
modality=modality,
|
modality=modality,
|
||||||
key=key,
|
key=key,
|
||||||
data=data,
|
data=data,
|
||||||
field=MultiModalSharedField(1),
|
field=MultiModalSharedField(batch_size=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
|
|||||||
|
|
||||||
def test_multimodal_kwargs():
|
def test_multimodal_kwargs():
|
||||||
e1 = MultiModalFieldElem(
|
e1 = MultiModalFieldElem(
|
||||||
"audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()
|
"audio",
|
||||||
|
"a0",
|
||||||
|
torch.zeros(1000, dtype=torch.bfloat16),
|
||||||
|
MultiModalBatchedField(),
|
||||||
)
|
)
|
||||||
e2 = MultiModalFieldElem(
|
e2 = MultiModalFieldElem(
|
||||||
"video",
|
"video",
|
||||||
"v0",
|
"v0",
|
||||||
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
|
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
|
||||||
MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
|
MultiModalFlatField(
|
||||||
|
slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
|
||||||
|
dim=0,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
e3 = MultiModalFieldElem(
|
e3 = MultiModalFieldElem(
|
||||||
"image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4)
|
"image",
|
||||||
|
"i0",
|
||||||
|
torch.zeros(1000, dtype=torch.int32),
|
||||||
|
MultiModalSharedField(batch_size=4),
|
||||||
)
|
)
|
||||||
e4 = MultiModalFieldElem(
|
e4 = MultiModalFieldElem(
|
||||||
"image",
|
"image",
|
||||||
"i1",
|
"i1",
|
||||||
torch.zeros(1000, dtype=torch.int32),
|
torch.zeros(1000, dtype=torch.int32),
|
||||||
MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2),
|
MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
|
||||||
)
|
)
|
||||||
audio = MultiModalKwargsItem.from_elems([e1])
|
audio = MultiModalKwargsItem.from_elems([e1])
|
||||||
video = MultiModalKwargsItem.from_elems([e2])
|
video = MultiModalKwargsItem.from_elems([e2])
|
||||||
@@ -138,8 +147,8 @@ def test_multimodal_kwargs():
|
|||||||
|
|
||||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||||
|
|
||||||
# expected total encoding length, should be 14306, +-20 for minor changes
|
# expected total encoding length, should be 14395, +-20 for minor changes
|
||||||
assert 14275 <= total_len <= 14325
|
assert 14375 <= total_len <= 14425
|
||||||
decoded = decoder.decode(encoded).mm[0]
|
decoded = decoder.decode(encoded).mm[0]
|
||||||
assert isinstance(decoded, MultiModalKwargsItems)
|
assert isinstance(decoded, MultiModalKwargsItems)
|
||||||
|
|
||||||
|
|||||||
@@ -787,10 +787,10 @@ class Glm4vVisionTransformer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
grid_thw: list[list[int]],
|
grid_thw: torch.Tensor | list[list[int]],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Convert grid_thw to tensor (always expecting list format now)
|
if isinstance(grid_thw, list):
|
||||||
grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
|
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
|
||||||
|
|
||||||
# patchify
|
# patchify
|
||||||
x = x.to(device=self.device, dtype=self.dtype)
|
x = x.to(device=self.device, dtype=self.dtype)
|
||||||
@@ -805,7 +805,8 @@ class Glm4vVisionTransformer(nn.Module):
|
|||||||
cu_seqlens = torch.repeat_interleave(
|
cu_seqlens = torch.repeat_interleave(
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
).cumsum(dim=0, dtype=torch.int32)
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||||
|
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
||||||
|
|
||||||
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
|
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
|
||||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||||
@@ -1548,7 +1549,6 @@ class Glm4vForConditionalGeneration(
|
|||||||
) -> tuple[torch.Tensor, ...]:
|
) -> tuple[torch.Tensor, ...]:
|
||||||
grid_thw = image_input["image_grid_thw"]
|
grid_thw = image_input["image_grid_thw"]
|
||||||
assert grid_thw.ndim == 2
|
assert grid_thw.ndim == 2
|
||||||
grid_thw_list = grid_thw.tolist()
|
|
||||||
|
|
||||||
if image_input["type"] == "image_embeds":
|
if image_input["type"] == "image_embeds":
|
||||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||||
@@ -1559,12 +1559,10 @@ class Glm4vForConditionalGeneration(
|
|||||||
self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
|
self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist())
|
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||||
|
|
||||||
merge_size = self.visual.spatial_merge_size
|
merge_size = self.visual.spatial_merge_size
|
||||||
sizes = (
|
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
|
||||||
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
|
|
||||||
// (merge_size * merge_size)
|
|
||||||
).tolist()
|
|
||||||
return image_embeds.split(sizes)
|
return image_embeds.split(sizes)
|
||||||
|
|
||||||
def _process_video_input(
|
def _process_video_input(
|
||||||
@@ -1572,7 +1570,6 @@ class Glm4vForConditionalGeneration(
|
|||||||
) -> tuple[torch.Tensor, ...]:
|
) -> tuple[torch.Tensor, ...]:
|
||||||
grid_thw = video_input["video_grid_thw"]
|
grid_thw = video_input["video_grid_thw"]
|
||||||
assert grid_thw.ndim == 2
|
assert grid_thw.ndim == 2
|
||||||
grid_thw_list = grid_thw.tolist()
|
|
||||||
|
|
||||||
if video_input["type"] == "video_embeds":
|
if video_input["type"] == "video_embeds":
|
||||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||||
@@ -1588,15 +1585,11 @@ class Glm4vForConditionalGeneration(
|
|||||||
rope_type="rope_3d",
|
rope_type="rope_3d",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
video_embeds = self.visual(
|
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||||
pixel_values_videos, grid_thw=grid_thw.tolist()
|
|
||||||
)
|
|
||||||
# Split concatenated embeddings for each video item.
|
# Split concatenated embeddings for each video item.
|
||||||
merge_size = self.visual.spatial_merge_size
|
merge_size = self.visual.spatial_merge_size
|
||||||
sizes = (
|
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
|
||||||
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
|
|
||||||
// (merge_size * merge_size)
|
|
||||||
).tolist()
|
|
||||||
return video_embeds.split(sizes)
|
return video_embeds.split(sizes)
|
||||||
|
|
||||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||||
|
|||||||
@@ -563,7 +563,7 @@ def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
|||||||
return dict(
|
return dict(
|
||||||
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
|
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
|
||||||
image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
|
image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
|
||||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -786,8 +786,6 @@ class HunYuanVLForConditionalGeneration(
|
|||||||
SupportsQuant,
|
SupportsQuant,
|
||||||
SupportsXDRoPE,
|
SupportsXDRoPE,
|
||||||
):
|
):
|
||||||
multimodal_cpu_fields = {"image_grid_thw"}
|
|
||||||
|
|
||||||
# To ensure correct weight loading and mapping.
|
# To ensure correct weight loading and mapping.
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={
|
orig_to_new_prefix={
|
||||||
|
|||||||
@@ -84,9 +84,9 @@ class SupportsMultiModal(Protocol):
|
|||||||
`vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
|
`vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
multimodal_cpu_fields: ClassVar[Set[str]] = frozenset()
|
multimodal_cpu_fields: ClassVar[Set[str] | None] = None
|
||||||
"""
|
"""
|
||||||
A set indicating CPU-only multimodal fields.
|
[DEPRECATED] A set indicating CPU-only multimodal fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_processor_factory: ClassVar[_ProcessorFactories]
|
_processor_factory: ClassVar[_ProcessorFactories]
|
||||||
@@ -279,6 +279,15 @@ def supports_multimodal(
|
|||||||
"please remove the override from your model."
|
"please remove the override from your model."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
multimodal_cpu_fields = getattr(model, "multimodal_cpu_fields", None)
|
||||||
|
if multimodal_cpu_fields is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"`multimodal_cpu_fields` is no longer effective, "
|
||||||
|
"please set `keep_on_cpu=True` in `MultiModalFieldConfig` "
|
||||||
|
"(refer to https://github.com/vllm-project/vllm/pull/30181), "
|
||||||
|
"and then remove the override from your model."
|
||||||
|
)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -201,8 +201,6 @@ class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder):
|
|||||||
dummy_inputs=OpenCUADummyInputsBuilder,
|
dummy_inputs=OpenCUADummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
|
class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
|
||||||
multimodal_cpu_fields = {"image_grid_thw"}
|
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
|
|||||||
@@ -1039,8 +1039,6 @@ class Qwen2_5_VLForConditionalGeneration(
|
|||||||
SupportsMultiModalPruning,
|
SupportsMultiModalPruning,
|
||||||
SupportsMRoPE,
|
SupportsMRoPE,
|
||||||
):
|
):
|
||||||
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
|
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
|
|||||||
@@ -811,14 +811,14 @@ def _create_qwen2vl_field_factory(
|
|||||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"image", image_embed_grid_sizes
|
"image", image_embed_grid_sizes
|
||||||
),
|
),
|
||||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
|
||||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"video", video_grid_sizes
|
"video", video_grid_sizes
|
||||||
),
|
),
|
||||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"video", video_embed_grid_sizes
|
"video", video_embed_grid_sizes
|
||||||
),
|
),
|
||||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
return _qwen2vl_field_config
|
return _qwen2vl_field_config
|
||||||
@@ -1131,8 +1131,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo])
|
|||||||
class Qwen2VLForConditionalGeneration(
|
class Qwen2VLForConditionalGeneration(
|
||||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||||
):
|
):
|
||||||
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
|
|
||||||
|
|
||||||
# To ensure correct weight loading and mapping.
|
# To ensure correct weight loading and mapping.
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={
|
orig_to_new_prefix={
|
||||||
@@ -1393,9 +1391,11 @@ class Qwen2VLForConditionalGeneration(
|
|||||||
else:
|
else:
|
||||||
pixel_values_videos = video_input["pixel_values_videos"]
|
pixel_values_videos = video_input["pixel_values_videos"]
|
||||||
if self.use_data_parallel:
|
if self.use_data_parallel:
|
||||||
grid_thw_list = grid_thw.tolist()
|
|
||||||
return run_dp_sharded_mrope_vision_model(
|
return run_dp_sharded_mrope_vision_model(
|
||||||
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
|
self.visual,
|
||||||
|
pixel_values_videos,
|
||||||
|
grid_thw.tolist(),
|
||||||
|
rope_type="rope_3d",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||||
|
|||||||
@@ -984,14 +984,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
|
|||||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"image", image_grid_sizes
|
"image", image_grid_sizes
|
||||||
),
|
),
|
||||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
|
||||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"video", video_grid_sizes
|
"video", video_grid_sizes
|
||||||
),
|
),
|
||||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"video", video_grid_sizes
|
"video", video_grid_sizes
|
||||||
),
|
),
|
||||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_prompt_updates(
|
||||||
@@ -1190,8 +1190,6 @@ class Qwen3VLForConditionalGeneration(
|
|||||||
SupportsMRoPE,
|
SupportsMRoPE,
|
||||||
SupportsEagle3,
|
SupportsEagle3,
|
||||||
):
|
):
|
||||||
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
|
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import Mapping, Sequence, Set
|
from collections.abc import Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
@@ -223,6 +223,23 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
|
|||||||
return a == b
|
return a == b
|
||||||
|
|
||||||
|
|
||||||
|
def _nested_tensors_h2d(
|
||||||
|
tensors: NestedTensors,
|
||||||
|
device: torch.types.Device,
|
||||||
|
) -> NestedTensors:
|
||||||
|
if device is None:
|
||||||
|
return tensors
|
||||||
|
|
||||||
|
return json_map_leaves(
|
||||||
|
(
|
||||||
|
lambda x: x.to(device=device, non_blocking=True)
|
||||||
|
if isinstance(x, torch.Tensor)
|
||||||
|
else x
|
||||||
|
),
|
||||||
|
tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
|
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
|
||||||
"""
|
"""
|
||||||
A dictionary containing nested tensors which have been batched via
|
A dictionary containing nested tensors which have been batched via
|
||||||
@@ -334,7 +351,7 @@ class MultiModalFieldElem:
|
|||||||
) # noqa: E721
|
) # noqa: E721
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class BaseMultiModalField(ABC):
|
class BaseMultiModalField(ABC):
|
||||||
"""
|
"""
|
||||||
Defines how to interpret tensor data belonging to a keyword argument in
|
Defines how to interpret tensor data belonging to a keyword argument in
|
||||||
@@ -342,6 +359,12 @@ class BaseMultiModalField(ABC):
|
|||||||
multi-modal items, and vice versa.
|
multi-modal items, and vice versa.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
keep_on_cpu: bool = False
|
||||||
|
"""
|
||||||
|
If `True`, then this field is excluded from being moved to the accelerator
|
||||||
|
when `MultiModalKwargsItems.get_data()` is called to batch the data.
|
||||||
|
"""
|
||||||
|
|
||||||
def _field_factory(self, *, modality: str, key: str):
|
def _field_factory(self, *, modality: str, key: str):
|
||||||
f = partial(
|
f = partial(
|
||||||
MultiModalFieldElem,
|
MultiModalFieldElem,
|
||||||
@@ -386,6 +409,7 @@ class BaseMultiModalField(ABC):
|
|||||||
self,
|
self,
|
||||||
elems: list[MultiModalFieldElem],
|
elems: list[MultiModalFieldElem],
|
||||||
*,
|
*,
|
||||||
|
device: torch.types.Device = None,
|
||||||
pin_memory: bool = False,
|
pin_memory: bool = False,
|
||||||
) -> NestedTensors:
|
) -> NestedTensors:
|
||||||
"""
|
"""
|
||||||
@@ -399,11 +423,17 @@ class BaseMultiModalField(ABC):
|
|||||||
if len(set(field_types)) > 1:
|
if len(set(field_types)) > 1:
|
||||||
raise ValueError(f"Cannot merge different {field_types=}")
|
raise ValueError(f"Cannot merge different {field_types=}")
|
||||||
|
|
||||||
|
if device is not None and self.keep_on_cpu:
|
||||||
|
device = "cpu"
|
||||||
|
if pin_memory and self.keep_on_cpu:
|
||||||
|
pin_memory = False
|
||||||
|
|
||||||
batch = [elem.data for elem in elems]
|
batch = [elem.data for elem in elems]
|
||||||
return self._reduce_data(batch, pin_memory=pin_memory)
|
out = self._reduce_data(batch, pin_memory=pin_memory)
|
||||||
|
return _nested_tensors_h2d(out, device=device)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class MultiModalBatchedField(BaseMultiModalField):
|
class MultiModalBatchedField(BaseMultiModalField):
|
||||||
"""
|
"""
|
||||||
Info:
|
Info:
|
||||||
@@ -445,7 +475,7 @@ class MultiModalBatchedField(BaseMultiModalField):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class MultiModalFlatField(BaseMultiModalField):
|
class MultiModalFlatField(BaseMultiModalField):
|
||||||
"""
|
"""
|
||||||
Info:
|
Info:
|
||||||
@@ -505,7 +535,7 @@ class MultiModalFlatField(BaseMultiModalField):
|
|||||||
return [e for elem in batch for e in elem]
|
return [e for elem in batch for e in elem]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class MultiModalSharedField(BaseMultiModalField):
|
class MultiModalSharedField(BaseMultiModalField):
|
||||||
"""
|
"""
|
||||||
Info:
|
Info:
|
||||||
@@ -532,9 +562,10 @@ class MultiModalSharedField(BaseMultiModalField):
|
|||||||
return batch[0]
|
return batch[0]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
class MultiModalFieldConfig:
|
class MultiModalFieldConfig:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def batched(modality: str):
|
def batched(modality: str, *, keep_on_cpu: bool = False):
|
||||||
"""
|
"""
|
||||||
Defines a field where an element in the batch is obtained by
|
Defines a field where an element in the batch is obtained by
|
||||||
indexing into the first dimension of the underlying data.
|
indexing into the first dimension of the underlying data.
|
||||||
@@ -542,6 +573,7 @@ class MultiModalFieldConfig:
|
|||||||
Args:
|
Args:
|
||||||
modality: The modality of the multi-modal item that uses this
|
modality: The modality of the multi-modal item that uses this
|
||||||
keyword argument.
|
keyword argument.
|
||||||
|
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -558,7 +590,7 @@ class MultiModalFieldConfig:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
return MultiModalFieldConfig(
|
return MultiModalFieldConfig(
|
||||||
field=MultiModalBatchedField(),
|
field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
|
||||||
modality=modality,
|
modality=modality,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -567,6 +599,8 @@ class MultiModalFieldConfig:
|
|||||||
modality: str,
|
modality: str,
|
||||||
slices: Sequence[slice] | Sequence[Sequence[slice]],
|
slices: Sequence[slice] | Sequence[Sequence[slice]],
|
||||||
dim: int = 0,
|
dim: int = 0,
|
||||||
|
*,
|
||||||
|
keep_on_cpu: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Defines a field where an element in the batch is obtained by
|
Defines a field where an element in the batch is obtained by
|
||||||
@@ -579,6 +613,7 @@ class MultiModalFieldConfig:
|
|||||||
slices (dim>0) that is used to extract the data corresponding
|
slices (dim>0) that is used to extract the data corresponding
|
||||||
to it.
|
to it.
|
||||||
dim: The dimension to extract data, default to 0.
|
dim: The dimension to extract data, default to 0.
|
||||||
|
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -613,12 +648,22 @@ class MultiModalFieldConfig:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
return MultiModalFieldConfig(
|
return MultiModalFieldConfig(
|
||||||
field=MultiModalFlatField(slices=slices, dim=dim),
|
field=MultiModalFlatField(
|
||||||
|
slices=slices,
|
||||||
|
dim=dim,
|
||||||
|
keep_on_cpu=keep_on_cpu,
|
||||||
|
),
|
||||||
modality=modality,
|
modality=modality,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0):
|
def flat_from_sizes(
|
||||||
|
modality: str,
|
||||||
|
size_per_item: "torch.Tensor",
|
||||||
|
dim: int = 0,
|
||||||
|
*,
|
||||||
|
keep_on_cpu: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Defines a field where an element in the batch is obtained by
|
Defines a field where an element in the batch is obtained by
|
||||||
slicing along the first dimension of the underlying data.
|
slicing along the first dimension of the underlying data.
|
||||||
@@ -629,6 +674,7 @@ class MultiModalFieldConfig:
|
|||||||
size_per_item: For each multi-modal item, the size of the slice
|
size_per_item: For each multi-modal item, the size of the slice
|
||||||
that is used to extract the data corresponding to it.
|
that is used to extract the data corresponding to it.
|
||||||
dim: The dimension to slice, default to 0.
|
dim: The dimension to slice, default to 0.
|
||||||
|
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -676,10 +722,20 @@ class MultiModalFieldConfig:
|
|||||||
for i in range(len(size_per_item))
|
for i in range(len(size_per_item))
|
||||||
]
|
]
|
||||||
|
|
||||||
return MultiModalFieldConfig.flat(modality, slices, dim=dim)
|
return MultiModalFieldConfig.flat(
|
||||||
|
modality,
|
||||||
|
slices,
|
||||||
|
dim=dim,
|
||||||
|
keep_on_cpu=keep_on_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shared(modality: str, batch_size: int):
|
def shared(
|
||||||
|
modality: str,
|
||||||
|
batch_size: int,
|
||||||
|
*,
|
||||||
|
keep_on_cpu: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Defines a field where an element in the batch is obtained by
|
Defines a field where an element in the batch is obtained by
|
||||||
taking the entirety of the underlying data.
|
taking the entirety of the underlying data.
|
||||||
@@ -690,6 +746,7 @@ class MultiModalFieldConfig:
|
|||||||
modality: The modality of the multi-modal item that uses this
|
modality: The modality of the multi-modal item that uses this
|
||||||
keyword argument.
|
keyword argument.
|
||||||
batch_size: The number of multi-modal items which share this data.
|
batch_size: The number of multi-modal items which share this data.
|
||||||
|
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -708,18 +765,15 @@ class MultiModalFieldConfig:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
return MultiModalFieldConfig(
|
return MultiModalFieldConfig(
|
||||||
field=MultiModalSharedField(batch_size),
|
field=MultiModalSharedField(
|
||||||
|
batch_size=batch_size,
|
||||||
|
keep_on_cpu=keep_on_cpu,
|
||||||
|
),
|
||||||
modality=modality,
|
modality=modality,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, field: BaseMultiModalField, modality: str) -> None:
|
field: BaseMultiModalField
|
||||||
super().__init__()
|
modality: str
|
||||||
|
|
||||||
self.field = field
|
|
||||||
self.modality = modality
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"
|
|
||||||
|
|
||||||
def build_elems(
|
def build_elems(
|
||||||
self,
|
self,
|
||||||
@@ -744,7 +798,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
|||||||
modality=modality,
|
modality=modality,
|
||||||
key="dummy",
|
key="dummy",
|
||||||
data=torch.empty(nbytes, dtype=torch.uint8),
|
data=torch.empty(nbytes, dtype=torch.uint8),
|
||||||
field=MultiModalSharedField(1),
|
field=MultiModalSharedField(batch_size=1),
|
||||||
)
|
)
|
||||||
return MultiModalKwargsItem.from_elems([mm_elem])
|
return MultiModalKwargsItem.from_elems([mm_elem])
|
||||||
|
|
||||||
@@ -844,7 +898,6 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
|
|||||||
*,
|
*,
|
||||||
device: torch.types.Device = None,
|
device: torch.types.Device = None,
|
||||||
pin_memory: bool = False,
|
pin_memory: bool = False,
|
||||||
cpu_fields: Set[str] = frozenset(),
|
|
||||||
) -> BatchedTensorInputs:
|
) -> BatchedTensorInputs:
|
||||||
"""Construct a dictionary of keyword arguments to pass to the model."""
|
"""Construct a dictionary of keyword arguments to pass to the model."""
|
||||||
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
||||||
@@ -859,21 +912,14 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
|
|||||||
elems_by_key[key].append(elem)
|
elems_by_key[key].append(elem)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
|
key: elems[0].field.reduce_data(
|
||||||
|
elems,
|
||||||
|
device=device,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
for key, elems in elems_by_key.items()
|
for key, elems in elems_by_key.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
if device is not None:
|
|
||||||
for k in data.keys() - cpu_fields:
|
|
||||||
data[k] = json_map_leaves(
|
|
||||||
(
|
|
||||||
lambda x: x.to(device=device, non_blocking=True)
|
|
||||||
if isinstance(x, torch.Tensor)
|
|
||||||
else x
|
|
||||||
),
|
|
||||||
data[k],
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -413,7 +413,7 @@ def group_mm_kwargs_by_modality(
|
|||||||
device: torch.types.Device = None,
|
device: torch.types.Device = None,
|
||||||
pin_memory: bool = False,
|
pin_memory: bool = False,
|
||||||
merge_by_field_config: bool | None = None,
|
merge_by_field_config: bool | None = None,
|
||||||
multimodal_cpu_fields: Set[str] = frozenset(),
|
multimodal_cpu_fields: Set[str] | None = None,
|
||||||
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
|
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
|
||||||
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
|
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
|
||||||
modality together into the same `MultiModalKwargs` instance.
|
modality together into the same `MultiModalKwargs` instance.
|
||||||
@@ -431,6 +431,11 @@ def group_mm_kwargs_by_modality(
|
|||||||
"The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
|
"The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
|
||||||
"is deprecated and will be removed in v0.13."
|
"is deprecated and will be removed in v0.13."
|
||||||
)
|
)
|
||||||
|
if multimodal_cpu_fields is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
|
||||||
|
"is deprecated and will be removed in v0.13."
|
||||||
|
)
|
||||||
|
|
||||||
from vllm.multimodal.inputs import MultiModalKwargsItems
|
from vllm.multimodal.inputs import MultiModalKwargsItems
|
||||||
|
|
||||||
@@ -440,7 +445,6 @@ def group_mm_kwargs_by_modality(
|
|||||||
mm_kwargs_data = mm_kwargs_items.get_data(
|
mm_kwargs_data = mm_kwargs_items.get_data(
|
||||||
device=device,
|
device=device,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
cpu_fields=multimodal_cpu_fields,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield modality, len(items_lst), mm_kwargs_data
|
yield modality, len(items_lst), mm_kwargs_data
|
||||||
|
|||||||
@@ -269,10 +269,11 @@ class MsgpackEncoder:
|
|||||||
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
|
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
|
||||||
if not name:
|
if not name:
|
||||||
raise TypeError(f"Unsupported field type: {field.__class__}")
|
raise TypeError(f"Unsupported field type: {field.__class__}")
|
||||||
|
|
||||||
# We just need to copy all of the field values in order
|
# We just need to copy all of the field values in order
|
||||||
# which will be then used to reconstruct the field.
|
# which will be then used to reconstruct the field.
|
||||||
field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
|
factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}
|
||||||
return name, *field_values
|
return name, factory_kw
|
||||||
|
|
||||||
|
|
||||||
class MsgpackDecoder:
|
class MsgpackDecoder:
|
||||||
@@ -392,15 +393,15 @@ class MsgpackDecoder:
|
|||||||
obj["data"] = self._decode_nested_tensors(obj["data"])
|
obj["data"] = self._decode_nested_tensors(obj["data"])
|
||||||
|
|
||||||
# Reconstruct the field processor using MultiModalFieldConfig
|
# Reconstruct the field processor using MultiModalFieldConfig
|
||||||
factory_meth_name, *field_args = obj["field"]
|
factory_meth_name, factory_kw = obj["field"]
|
||||||
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
|
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
|
||||||
|
|
||||||
# Special case: decode the union "slices" field of
|
# Special case: decode the union "slices" field of
|
||||||
# MultiModalFlatField
|
# MultiModalFlatField
|
||||||
if factory_meth_name == "flat":
|
if factory_meth_name == "flat":
|
||||||
field_args[0] = self._decode_nested_slices(field_args[0])
|
factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"])
|
||||||
|
|
||||||
obj["field"] = factory_meth(None, *field_args).field
|
obj["field"] = factory_meth("", **factory_kw).field
|
||||||
return MultiModalFieldElem(**obj)
|
return MultiModalFieldElem(**obj)
|
||||||
|
|
||||||
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
|
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
|
||||||
|
|||||||
@@ -1097,7 +1097,6 @@ class GPUModelRunner(
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
merge_by_field_config=model.merge_by_field_config,
|
merge_by_field_config=model.merge_by_field_config,
|
||||||
multimodal_cpu_fields=model.multimodal_cpu_fields,
|
|
||||||
):
|
):
|
||||||
mm_kwargs_combined.update(mm_kwargs_group)
|
mm_kwargs_combined.update(mm_kwargs_group)
|
||||||
|
|
||||||
@@ -2109,7 +2108,6 @@ class GPUModelRunner(
|
|||||||
mm_kwargs,
|
mm_kwargs,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
multimodal_cpu_fields=model.multimodal_cpu_fields,
|
|
||||||
):
|
):
|
||||||
curr_group_outputs: list[torch.Tensor] = []
|
curr_group_outputs: list[torch.Tensor] = []
|
||||||
|
|
||||||
@@ -2135,7 +2133,6 @@ class GPUModelRunner(
|
|||||||
[video_mm_kwargs_item],
|
[video_mm_kwargs_item],
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
multimodal_cpu_fields=model.multimodal_cpu_fields,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -3887,14 +3884,12 @@ class GPUModelRunner(
|
|||||||
dummy_mm_item = dummy_mm_data[modality][0]
|
dummy_mm_item = dummy_mm_data[modality][0]
|
||||||
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
|
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
|
||||||
|
|
||||||
model = cast(SupportsMultiModal, self.model)
|
|
||||||
return next(
|
return next(
|
||||||
mm_kwargs_group
|
mm_kwargs_group
|
||||||
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
|
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||||
dummy_mm_items,
|
dummy_mm_items,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
multimodal_cpu_fields=model.multimodal_cpu_fields,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -969,7 +969,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
mm_kwargs,
|
mm_kwargs,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
multimodal_cpu_fields=model.multimodal_cpu_fields,
|
|
||||||
):
|
):
|
||||||
# Run the encoder.
|
# Run the encoder.
|
||||||
# `curr_group_outputs` is either of the following:
|
# `curr_group_outputs` is either of the following:
|
||||||
@@ -2050,14 +2049,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
dummy_mm_item = dummy_mm_data[modality][0]
|
dummy_mm_item = dummy_mm_data[modality][0]
|
||||||
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
|
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
|
||||||
|
|
||||||
model = cast(SupportsMultiModal, self.model)
|
|
||||||
return next(
|
return next(
|
||||||
grouped_mm_kwargs
|
grouped_mm_kwargs
|
||||||
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
|
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
|
||||||
dummy_mm_items,
|
dummy_mm_items,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
multimodal_cpu_fields=model.multimodal_cpu_fields,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user