@@ -1,4 +1,5 @@
-from typing import Dict, Iterable, List, Optional, Protocol, Tuple
+from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
+                    Union, overload)
 
 import numpy as np
 import torch
@@ -55,6 +56,44 @@ def init_vllm_registered_model(
     )
 
 
+@overload
+def flatten_bn(x: torch.Tensor) -> torch.Tensor:
+    ...
+
+
+@overload
+def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
+    ...
+
+
+@overload
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: Literal[True],
+) -> torch.Tensor:
+    ...
+
+
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
"""
|
||||
Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
|
||||
|
||||
The input tensor should have shape ``(B, N, ...)```.
|
||||
"""
|
||||
+    if isinstance(x, torch.Tensor):
+        return x.flatten(0, 1)
+
+    if concat:
+        return torch.cat(x)
+
+    return [x_n for x_b in x for x_n in x_b]
+
+
 def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
     """
     Recursively concatenates NestedTensors along any heterogeneously sized
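
Not part of the diff — a quick sketch of how the new ``flatten_bn`` overloads behave, assuming ``B = 2`` requests with image feature tensors (the import path assumes this diff touches ``vllm/model_executor/models/utils.py``):

import torch
from vllm.model_executor.models.utils import flatten_bn

# Homogeneous batch: a single (B, N, ...) tensor flattens to (B * N, ...).
batched = torch.randn(2, 3, 4)
assert flatten_bn(batched).shape == (6, 4)

# Heterogeneous batch: a list of (N_i, ...) tensors.
per_request = [torch.randn(3, 4), torch.randn(5, 4)]

# Default: returns a flat list of the 3 + 5 individual row tensors.
assert len(flatten_bn(per_request)) == 8

# concat=True: returns a single concatenated (8, 4) tensor instead.
assert flatten_bn(per_request, concat=True).shape == (8, 4)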
@@ -93,7 +132,8 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
     This updates ``inputs_embeds`` in place.
     """
     mask = (input_ids == placeholder_token_id)
-    num_expected_tokens = mask.sum()
+    num_expected_tokens = mask.sum().item()
+    assert isinstance(num_expected_tokens, int)
 
     flattened = _flatten_embeddings(multimodal_embeddings)
     *dims, embed_dim = flattened.shape
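
Not part of the diff — a minimal sketch of why ``.item()`` is needed before the ``isinstance`` check: ``mask.sum()`` returns a 0-dim tensor, not a Python ``int`` (the placeholder token id below is hypothetical):

import torch

input_ids = torch.tensor([1, 32000, 2, 32000, 32000])
mask = (input_ids == 32000)

# Summing a boolean mask yields a 0-dim tensor, which fails isinstance(..., int).
assert not isinstance(mask.sum(), int)

# .item() extracts the value as a plain Python int, so the new assert passes.
num_expected_tokens = mask.sum().item()
assert isinstance(num_expected_tokens, int) and num_expected_tokens == 3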