[Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import Dict, Iterable, List, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.func import functional_call
|
||||
@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.loader import build_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.multimodal import BatchedTensors
|
||||
from vllm.multimodal.base import NestedTensors
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@@ -54,9 +55,34 @@ def init_vllm_registered_model(
|
||||
)
|
||||
|
||||
|
||||
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
|
||||
"""
|
||||
Recursively concatenates NestedTensors along any heterogeneously sized
|
||||
dimensions.
|
||||
"""
|
||||
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
return embeddings
|
||||
|
||||
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
|
||||
|
||||
|
||||
def _embedding_count_expression(embeddings: NestedTensors) -> str:
|
||||
"""
|
||||
Constructs a debugging representation of the number of embeddings in the
|
||||
NestedTensors.
|
||||
"""
|
||||
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
|
||||
|
||||
return " + ".join(
|
||||
_embedding_count_expression(inner) for inner in embeddings)
|
||||
|
||||
|
||||
def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: NestedTensors,
                                placeholder_token_id: int) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Raises:
        ValueError: if the total number of multimodal embedding vectors does
            not match the number of placeholder tokens in ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place and also returns it.
    """
    mask = (input_ids == placeholder_token_id)
    num_expected_tokens = mask.sum()

    # Collapse the (possibly nested) embeddings into a single 2-D-viewable
    # tensor so a single masked assignment handles every input layout.
    flattened = _flatten_embeddings(multimodal_embeddings)
    *dims, embed_dim = flattened.shape
    num_multimodal_embeddings = np.prod(dims)
    if num_multimodal_embeddings != num_expected_tokens:
        # Render the per-item counts (e.g. "2 x 576 + 1 x 576") so the
        # mismatch is easy to trace back to a specific prompt/image.
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {num_multimodal_embeddings} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim)
    return inputs_embeds
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user