[Bugfix][VLM] Fix incompatibility between #7902 and #7230 (#7948)

Cyrus Leung
2024-08-28 23:11:18 +08:00
committed by GitHub
parent 98c12cffe5
commit ef9baee3c5
10 changed files with 120 additions and 92 deletions


@@ -1,4 +1,5 @@
-from typing import Dict, Iterable, List, Optional, Protocol, Tuple
+from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
+                    Union, overload)
 
 import numpy as np
 import torch
@@ -55,6 +56,44 @@ def init_vllm_registered_model(
     )
 
 
+@overload
+def flatten_bn(x: torch.Tensor) -> torch.Tensor:
+    ...
+
+
+@overload
+def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
+    ...
+
+
+@overload
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: Literal[True],
+) -> torch.Tensor:
+    ...
+
+
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
+    """
+    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
+
+    The input tensor should have shape ``(B, N, ...)``.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.flatten(0, 1)
+
+    if concat:
+        return torch.cat(x)
+
+    return [x_n for x_b in x for x_n in x_b]
+
+
 def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
     """
     Recursively concatenates NestedTensors along any heterogeneously sized
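
For reference, a minimal sketch (not part of the diff) of how the new flatten_bn helper treats each input form and the concat flag, based on the three code paths above:

import torch

batch = torch.zeros(2, 3, 4)                       # a (B, N, ...) tensor
assert batch.flatten(0, 1).shape == (6, 4)         # tensor input: merged to (B*N, ...)

# List input: one (N, ...) tensor per batch item, where N may differ per item.
items = [torch.zeros(3, 4), torch.zeros(5, 4)]
flat_list = [x_n for x_b in items for x_n in x_b]  # concat=False: list of row slices
assert len(flat_list) == 8 and flat_list[0].shape == (4,)

flat_cat = torch.cat(items)                        # concat=True: one stacked tensor
assert flat_cat.shape == (8, 4)

The overloads exist so that a type checker can narrow the return type: a tensor input or concat=True yields a single tensor, while a list input with the default concat=False yields a list.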
@@ -93,7 +132,8 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
     This updates ``inputs_embeds`` in place.
     """
     mask = (input_ids == placeholder_token_id)
-    num_expected_tokens = mask.sum()
+    num_expected_tokens = mask.sum().item()
+    assert isinstance(num_expected_tokens, int)
 
     flattened = _flatten_embeddings(multimodal_embeddings)
     *dims, embed_dim = flattened.shape
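
This hunk is the actual compatibility fix: Tensor.sum() returns a zero-dimensional tensor rather than a Python int, so the added isinstance assertion only holds after converting with .item(). A minimal sketch (not from the diff) of the distinction:

import torch

mask = torch.tensor([True, False, True])
assert isinstance(mask.sum(), torch.Tensor)   # 0-dim tensor, not an int
assert isinstance(mask.sum().item(), int)     # .item() yields a plain Python int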