diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index f765d945c..686649733 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -386,19 +386,44 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): vision_embeddings = vision_embeddings.pooler_output if isinstance(vision_embeddings, torch.Tensor): - if vision_embeddings.ndim == 2: - vision_embeddings = vision_embeddings.unsqueeze(0) + split_sizes = num_image_patches.flatten().tolist() + total_patches = sum(split_sizes) - # Embeddings have to be 2D tensors of length `num_images` - # but transformers returns concat tensors if each patch - # is of different size. We split it back to make vLLM happy - vision_embeddings = torch.split( - vision_embeddings, num_image_patches.flatten().tolist() - ) - vision_embeddings = [ - embed.flatten(start_dim=0, end_dim=-2) - for embed in vision_embeddings - ] + # Flatten to 2D: [total_tokens, hidden_dim] + if vision_embeddings.ndim == 3: + vision_embeddings = vision_embeddings.view( + -1, vision_embeddings.shape[-1] + ) + + total_tokens = vision_embeddings.shape[0] + if total_tokens == total_patches: + # Direct match: num_image_patches are actual token counts + # (e.g., Qwen2.5-VL style) + token_split_sizes = split_sizes + elif total_patches > 0 and total_tokens % total_patches == 0: + # Uniform expansion: each patch expands to N tokens + # (e.g., Idefics3 style) + tokens_per_patch = total_tokens // total_patches + token_split_sizes = [s * tokens_per_patch for s in split_sizes] + elif total_patches > 0: + # Mismatch (profiling with dummy data) - pad/truncate + if total_tokens == 0: + raise ValueError( + "Vision encoder returned empty embeddings. " + f"Expected {total_patches} patches from " + f"num_image_patches={split_sizes}" + ) + if total_tokens < total_patches: + repeat_factor = ( + total_patches + total_tokens - 1 + ) // total_tokens + vision_embeddings = vision_embeddings.repeat(repeat_factor, 1) + vision_embeddings = vision_embeddings[:total_patches] + token_split_sizes = split_sizes + else: + return [] + + return list(torch.split(vision_embeddings, token_split_sizes, dim=0)) return vision_embeddings else: