[Bugfix][VLM] Fix transformers backend embed_multimodal for Qwen2.5-VL profiling (#32969)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Author: Andreas Karatzas
Date: 2026-01-25 18:34:05 -06:00
Committed by: GitHub
Parent: a698e8e7ad
Commit: 22aeb43007

@@ -386,19 +386,44 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
                 vision_embeddings = vision_embeddings.pooler_output
             if isinstance(vision_embeddings, torch.Tensor):
                 if vision_embeddings.ndim == 2:
                     vision_embeddings = vision_embeddings.unsqueeze(0)
+                split_sizes = num_image_patches.flatten().tolist()
+                total_patches = sum(split_sizes)
-                # Embeddings have to be 2D tensors of length `num_images`
-                # but transformers returns concat tensors if each patch
-                # is of different size. We split it back to make vLLM happy
-                vision_embeddings = torch.split(
-                    vision_embeddings, num_image_patches.flatten().tolist()
-                )
-                vision_embeddings = [
-                    embed.flatten(start_dim=0, end_dim=-2)
-                    for embed in vision_embeddings
-                ]
+                # Flatten to 2D: [total_tokens, hidden_dim]
+                if vision_embeddings.ndim == 3:
+                    vision_embeddings = vision_embeddings.view(
+                        -1, vision_embeddings.shape[-1]
+                    )
+                total_tokens = vision_embeddings.shape[0]
+                if total_tokens == total_patches:
+                    # Direct match: num_image_patches are actual token counts
+                    # (e.g., Qwen2.5-VL style)
+                    token_split_sizes = split_sizes
+                elif total_patches > 0 and total_tokens % total_patches == 0:
+                    # Uniform expansion: each patch expands to N tokens
+                    # (e.g., Idefics3 style)
+                    tokens_per_patch = total_tokens // total_patches
+                    token_split_sizes = [s * tokens_per_patch for s in split_sizes]
+                elif total_patches > 0:
+                    # Mismatch (profiling with dummy data) - pad/truncate
+                    if total_tokens == 0:
+                        raise ValueError(
+                            "Vision encoder returned empty embeddings. "
+                            f"Expected {total_patches} patches from "
+                            f"num_image_patches={split_sizes}"
+                        )
+                    if total_tokens < total_patches:
+                        repeat_factor = (
+                            total_patches + total_tokens - 1
+                        ) // total_tokens
+                        vision_embeddings = vision_embeddings.repeat(repeat_factor, 1)
+                    vision_embeddings = vision_embeddings[:total_patches]
+                    token_split_sizes = split_sizes
+                else:
+                    return []
+                return list(torch.split(vision_embeddings, token_split_sizes, dim=0))
             return vision_embeddings
         else:
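
For reference, the split-size inference introduced above can be exercised in isolation. The sketch below re-implements the three branches as a standalone helper; the name split_vision_embeddings and the toy shapes are illustrative assumptions, not part of the commit.

# Standalone sketch of the split-size inference added in this commit.
# Helper name and shapes are hypothetical; branch logic mirrors the diff.
import torch


def split_vision_embeddings(
    vision_embeddings: torch.Tensor, num_image_patches: torch.Tensor
) -> list[torch.Tensor]:
    split_sizes = num_image_patches.flatten().tolist()
    total_patches = sum(split_sizes)
    if vision_embeddings.ndim == 3:
        # Collapse [num_images, tokens, hidden] to [total_tokens, hidden]
        vision_embeddings = vision_embeddings.view(-1, vision_embeddings.shape[-1])
    total_tokens = vision_embeddings.shape[0]
    if total_tokens == total_patches:
        # Direct match: patch counts are already per-image token counts
        token_split_sizes = split_sizes
    elif total_patches > 0 and total_tokens % total_patches == 0:
        # Uniform expansion: every patch maps to the same number of tokens
        tokens_per_patch = total_tokens // total_patches
        token_split_sizes = [s * tokens_per_patch for s in split_sizes]
    elif total_patches > 0:
        # Mismatch (e.g., profiling with dummy data): tile the rows,
        # then truncate to exactly total_patches
        if total_tokens == 0:
            raise ValueError("Vision encoder returned empty embeddings.")
        if total_tokens < total_patches:
            repeat_factor = (total_patches + total_tokens - 1) // total_tokens
            vision_embeddings = vision_embeddings.repeat(repeat_factor, 1)
        vision_embeddings = vision_embeddings[:total_patches]
        token_split_sizes = split_sizes
    else:
        return []
    return list(torch.split(vision_embeddings, token_split_sizes, dim=0))


# Direct match (Qwen2.5-VL style): 10 rows for 6 + 4 patches.
parts = split_vision_embeddings(torch.randn(10, 8), torch.tensor([6, 4]))
assert [p.shape[0] for p in parts] == [6, 4]

# Uniform expansion (Idefics3 style): 30 rows = 10 patches x 3 tokens each.
parts = split_vision_embeddings(torch.randn(30, 8), torch.tensor([6, 4]))
assert [p.shape[0] for p in parts] == [18, 12]

# Profiling mismatch: 7 rows for 10 expected patches -> repeat, then truncate.
parts = split_vision_embeddings(torch.randn(7, 8), torch.tensor([6, 4]))
assert [p.shape[0] for p in parts] == [6, 4]

Note that (total_patches + total_tokens - 1) // total_tokens is integer ceiling division, so tiling by that factor always yields at least total_patches rows before the truncation step.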