[Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902)

Peter Salas authored 2024-08-27 18:53:56 -07:00; committed by GitHub
parent 9c71c97ae2, commit fab5f53e2d
15 changed files with 214 additions and 60 deletions

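For context: NestedTensors (imported below from vllm.multimodal.base) is either a single torch.Tensor or a possibly nested list of tensors, so one prompt can now carry several images, each producing its own embedding tensor. A rough sketch of the two representations, with toy shapes chosen purely for illustration:

import torch

# Old-style batched input: one stacked tensor, which forces every prompt
# to contribute the same number of image tokens.
batched = torch.zeros(2, 576, 4096)  # 2 prompts x 576 tokens x hidden

# New-style nested input: one entry per prompt, so prompt 0 can carry two
# images of different sizes while prompt 1 carries one.
nested = [
    [torch.zeros(576, 4096), torch.zeros(144, 4096)],  # prompt 0
    [torch.zeros(576, 4096)],                          # prompt 1
]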

@@ -1,5 +1,6 @@
 from typing import Dict, Iterable, List, Optional, Protocol, Tuple
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torch.func import functional_call
@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.models import ModelRegistry
-from vllm.multimodal import BatchedTensors
+from vllm.multimodal.base import NestedTensors
 from vllm.utils import is_pin_memory_available
@@ -54,9 +55,34 @@ def init_vllm_registered_model(
     )
 
 
+def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
+    """
+    Recursively concatenates NestedTensors along any heterogeneously sized
+    dimensions.
+    """
+
+    if isinstance(embeddings, torch.Tensor):
+        return embeddings
+
+    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
+
+
+def _embedding_count_expression(embeddings: NestedTensors) -> str:
+    """
+    Constructs a debugging representation of the number of embeddings in the
+    NestedTensors.
+    """
+
+    if isinstance(embeddings, torch.Tensor):
+        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
+
+    return " + ".join(
+        _embedding_count_expression(inner) for inner in embeddings)
+
+
 def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                 inputs_embeds: torch.Tensor,
-                                multimodal_embeddings: BatchedTensors,
+                                multimodal_embeddings: NestedTensors,
                                 placeholder_token_id: int) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
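As a quick sanity check (not part of the commit), the two helpers above behave as follows on a hand-built nested input; all shapes are hypothetical:

import torch

nested = [torch.zeros(3, 8), [torch.zeros(2, 8), torch.zeros(4, 8)]]

flat = _flatten_embeddings(nested)
print(flat.shape)                           # torch.Size([9, 8])
print(_embedding_count_expression(nested))  # prints: 3 + 2 + 4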
@@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
     mask = (input_ids == placeholder_token_id)
     num_expected_tokens = mask.sum()
 
-    if isinstance(multimodal_embeddings, torch.Tensor):
-        batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
-        total_tokens = batch_size * batch_tokens
-        if num_expected_tokens != total_tokens:
-            expr = f"{batch_size} x {batch_tokens}"
-            raise ValueError(
-                f"Attempted to assign {expr} = {total_tokens} "
-                f"multimodal tokens to {num_expected_tokens} placeholders")
-
-        inputs_embeds[mask] = multimodal_embeddings.view(
-            total_tokens, embed_dim)
-    else:
-        size_per_batch = [t.shape[0] for t in multimodal_embeddings]
-        total_tokens = sum(size_per_batch)
-        if num_expected_tokens != total_tokens:
-            expr = ' + '.join(map(str, size_per_batch))
-            raise ValueError(
-                f"Attempted to assign {expr} = {total_tokens} "
-                f"multimodal tokens to {num_expected_tokens} placeholders")
-
-        inputs_embeds[mask] = torch.cat(multimodal_embeddings)
+    flattened = _flatten_embeddings(multimodal_embeddings)
+    *dims, embed_dim = flattened.shape
+    num_multimodal_embeddings = np.prod(dims)
+    if num_multimodal_embeddings != num_expected_tokens:
+        expr = _embedding_count_expression(multimodal_embeddings)
+        raise ValueError(
+            f"Attempted to assign {expr} = {num_multimodal_embeddings} "
+            f"multimodal tokens to {num_expected_tokens} placeholders")
 
+    inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim)
     return inputs_embeds
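For reference, a hypothetical end-to-end call of the rewritten function; the placeholder id 32000 and all shapes are illustrative only:

import torch

input_ids = torch.tensor([1, 32000, 32000, 32000, 2])
inputs_embeds = torch.zeros(5, 8)   # one row per prompt token
image_embeds = [torch.ones(3, 8)]   # one image -> 3 embedding tokens

out = merge_multimodal_embeddings(
    input_ids, inputs_embeds, image_embeds, placeholder_token_id=32000)
# Rows 1-3 of out now hold the image embeddings. With only two
# placeholders the call would instead raise:
#   ValueError: Attempted to assign 3 = 3 multimodal tokens to 2 placeholders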