[Model] Use merge_by_field_config for MM models (O-P) (#26776)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-10-14 17:42:45 +08:00
Committed by: GitHub
Parent: d2f816d6ff
Commit: 74704d4553
3 changed files with 30 additions and 122 deletions


@@ -50,13 +50,12 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .phi4mm_audio import AudioEmbedding
-from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
# <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
@@ -467,7 +466,7 @@ class Phi4MMImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
-data: Annotated[
+pixel_values: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape(
"bn", "p", 3, "h", "w", dynamic_dims={"p"}
@@ -499,7 +498,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):
type: Literal["audio_features"]
-data: Annotated[
+audio_features: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
]
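Review note: the rename from a generic `data` field to `pixel_values` / `audio_features` makes each schema attribute match the kwarg name the model receives once fields are merged, and `dynamic_dims` lets the patch and time dimensions vary per item. Below is a minimal sketch of how such a schema is expected to behave, assuming `TensorSchema` validates the `TensorShape` annotations at construction time and maps the leading "bn" dimension to the list length when given a list of tensors; the 336x336 patch size is illustrative:

    from typing import Annotated, Literal

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape

    class ToyImagePixelInputs(TensorSchema):
        """
        Dimensions:
            bn: batch size * number of images
            p: number of HD patches (ragged across images)
            h, w: patch height and width
        """
        type: Literal["pixel_values"]
        pixel_values: Annotated[
            torch.Tensor | list[torch.Tensor],
            TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}),
        ]

    # Two images with different patch counts ("p" = 4 vs 7) both validate,
    # because "p" is declared dynamic; a wrong channel count would raise.
    inputs = ToyImagePixelInputs(
        type="pixel_values",
        pixel_values=[
            torch.zeros(4, 3, 336, 336),
            torch.zeros(7, 3, 336, 336),
        ],
    )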
@@ -986,6 +985,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
Implements the Phi-4-multimodal-instruct model in vLLM.
"""
+merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": [
"qkv_proj",
@@ -1094,7 +1095,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if audio_features is not None:
return Phi4MMAudioFeatureInputs(
type="audio_features", data=flatten_bn(audio_features)
type="audio_features",
audio_features=audio_features,
)
if audio_embeds is not None:
@@ -1119,7 +1121,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if audio_input["type"] == "audio_embeds":
return audio_input["data"]
-audio_features = audio_input["data"]
+audio_features = audio_input["audio_features"]
# (e.g. multiple examples) and the second dim is the multi-audio dim
# (e.g. multiple audios in the same example)
@@ -1136,8 +1138,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Phi4MMImagePixelInputs | None:
-input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
-if input_image_embeds is None:
+pixel_values = kwargs.get("input_image_embeds")
+if pixel_values is None:
return None
image_sizes = kwargs.get("image_sizes")
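Review note: only the local variable is renamed here; the kwarg is still looked up under the `input_image_embeds` key, since that name is fixed by the processor's field config rather than by the schema. For context, a hedged sketch of what such a field-config declaration looks like; the method name mirrors vLLM's `_get_mm_fields_config` processor convention, but the exact entries are illustrative rather than copied from Phi4MM's processor:

    from vllm.multimodal.inputs import MultiModalFieldConfig

    def _get_mm_fields_config_sketch():
        # One entry per multimodal kwarg; "batched" means one item per image
        # or audio clip along the "bn" dimension after the engine-side merge.
        return dict(
            input_image_embeds=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_attention_mask=MultiModalFieldConfig.batched("image"),
            num_img_tokens=MultiModalFieldConfig.batched("image"),
            audio_features=MultiModalFieldConfig.batched("audio"),
        )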
@@ -1149,52 +1151,9 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
and num_img_tokens is not None
), "Missing image inputs"
-if is_list_of(input_image_embeds, torch.Tensor):
-assert all(p.dim() == 5 for p in input_image_embeds), (
-"Incorrect image inputs"
-)
-# list len is batch_size.
-# each tensor has dimension: num_img_per_example, num_hd_patches,
-# channels, height, width.
-# need to pad along num_hd_patches.
-# mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
-input_image_embeds = cat_with_pad(input_image_embeds, dim=0)
-elif isinstance(input_image_embeds, torch.Tensor):
-# dimension: batch_size, num_img_per_example, num_hd_patches,
-# channels, height, width.
-# we flatten first 2 dims to make it a single large batch for
-# SigLIP Encoder.
-assert input_image_embeds.dim() == 6, "Incorrect image inputs"
-input_image_embeds = input_image_embeds.flatten(0, 1)
-else:
-raise ValueError("Incorrect input_image_embeds inputs")
-if isinstance(image_attention_mask, list):
-image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
-elif isinstance(image_attention_mask, torch.Tensor):
-image_attention_mask = image_attention_mask.flatten(0, 1)
-else:
-raise ValueError("Incorrect image_attention_mask inputs")
-if isinstance(image_sizes, list):
-image_sizes = torch.cat(image_sizes, dim=0)
-elif isinstance(image_sizes, torch.Tensor):
-image_sizes = image_sizes.flatten(0, 1)
-else:
-raise ValueError("Incorrect image_sizes inputs")
-if isinstance(num_img_tokens, list):
-num_img_tokens = [
-n for num_tensor in num_img_tokens for n in num_tensor.tolist()
-]
-elif isinstance(num_img_tokens, torch.Tensor):
-num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
-else:
-raise ValueError("Incorrect num_img_tokens inputs")
return Phi4MMImagePixelInputs(
type="pixel_values",
-data=input_image_embeds,
+pixel_values=pixel_values,
image_sizes=image_sizes,
image_attention_mask=image_attention_mask,
num_img_tokens=num_img_tokens,
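Review note: the deleted block existed only to normalize per-prompt nesting by hand: padded concatenation of ragged lists, flattening the leading (batch, num_images) dims, and per-kwarg type checks. With the engine-side merge plus the schema's `dynamic_dims`, inputs arrive already normalized, so the constructor call is all that remains. For context on what was dropped, a toy reimplementation of the padded concatenation (`cat_with_pad` is the real helper's name; zero right-padding is an assumption about its behavior):

    import torch

    def toy_cat_with_pad(
        tensors: list[torch.Tensor], dim: int = 0
    ) -> torch.Tensor:
        # Pad each tensor up to the per-dimension maximum (except along
        # `dim`), then concatenate along `dim`; e.g. images with different
        # num_hd_patches become one tensor zero-padded along the patch dim.
        ndim = tensors[0].dim()
        max_sizes = [max(t.size(d) for t in tensors) for d in range(ndim)]
        padded = []
        for t in tensors:
            shape = [
                t.size(d) if d == dim else max_sizes[d] for d in range(ndim)
            ]
            out = t.new_zeros(shape)
            out[tuple(slice(0, s) for s in t.shape)] = t
            padded.append(out)
        return torch.cat(padded, dim=dim)

    # (1, 4, 3, 336, 336) and (1, 7, 3, 336, 336) -> (2, 7, 3, 336, 336)
    merged = toy_cat_with_pad(
        [torch.zeros(1, 4, 3, 336, 336), torch.zeros(1, 7, 3, 336, 336)]
    )
    assert merged.shape == (2, 7, 3, 336, 336)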
@@ -1223,7 +1182,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
self, image_input: Phi4MMImagePixelInputs
) -> list[torch.Tensor]:
dtype = next(self.vision_encoder.parameters()).dtype
pixel_values = image_input["data"].to(dtype)
pixel_values = image_input["pixel_values"].to(dtype)
image_sizes = image_input["image_sizes"]
image_attention_mask = image_input["image_attention_mask"]
image_embeds = self.vision_encoder(