[Bugfix][VLM] Fix transformers backend embed_multimodal for Qwen2.5-VL profiling (#32969)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -386,19 +386,44 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
vision_embeddings = vision_embeddings.pooler_output
|
||||
|
||||
if isinstance(vision_embeddings, torch.Tensor):
|
||||
if vision_embeddings.ndim == 2:
|
||||
vision_embeddings = vision_embeddings.unsqueeze(0)
|
||||
split_sizes = num_image_patches.flatten().tolist()
|
||||
total_patches = sum(split_sizes)
|
||||
|
||||
# vLLM expects a sequence of 2D embedding tensors, one per image (`num_images` in total),
|
||||
# but transformers returns a single concatenated tensor if each patch
|
||||
# is of different size. We split it back to make vLLM happy
|
||||
vision_embeddings = torch.split(
|
||||
vision_embeddings, num_image_patches.flatten().tolist()
|
||||
)
|
||||
vision_embeddings = [
|
||||
embed.flatten(start_dim=0, end_dim=-2)
|
||||
for embed in vision_embeddings
|
||||
]
|
||||
# Flatten to 2D: [total_tokens, hidden_dim]
|
||||
if vision_embeddings.ndim == 3:
|
||||
vision_embeddings = vision_embeddings.view(
|
||||
-1, vision_embeddings.shape[-1]
|
||||
)
|
||||
|
||||
total_tokens = vision_embeddings.shape[0]
|
||||
if total_tokens == total_patches:
|
||||
# Direct match: num_image_patches are actual token counts
|
||||
# (e.g., Qwen2.5-VL style)
|
||||
token_split_sizes = split_sizes
|
||||
elif total_patches > 0 and total_tokens % total_patches == 0:
|
||||
# Uniform expansion: each patch expands to N tokens
|
||||
# (e.g., Idefics3 style)
|
||||
tokens_per_patch = total_tokens // total_patches
|
||||
token_split_sizes = [s * tokens_per_patch for s in split_sizes]
|
||||
elif total_patches > 0:
|
||||
# Mismatch (e.g. profiling with dummy data): repeat rows if too few, then truncate to total_patches
|
||||
if total_tokens == 0:
|
||||
raise ValueError(
|
||||
"Vision encoder returned empty embeddings. "
|
||||
f"Expected {total_patches} patches from "
|
||||
f"num_image_patches={split_sizes}"
|
||||
)
|
||||
if total_tokens < total_patches:
|
||||
repeat_factor = (
|
||||
total_patches + total_tokens - 1
|
||||
) // total_tokens
|
||||
vision_embeddings = vision_embeddings.repeat(repeat_factor, 1)
|
||||
vision_embeddings = vision_embeddings[:total_patches]
|
||||
token_split_sizes = split_sizes
|
||||
else:
|
||||
return []
|
||||
|
||||
return list(torch.split(vision_embeddings, token_split_sizes, dim=0))
|
||||
|
||||
return vision_embeddings
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user