[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
@@ -4,7 +4,7 @@
 import itertools
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
+from typing import Any, Literal, Optional, Protocol, Union, overload
 
 import torch
 import torch.nn as nn
@@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
 
 def _merge_multimodal_embeddings(
     inputs_embeds: torch.Tensor,
-    is_multimodal: torch.Tensor,
     multimodal_embeddings: NestedTensors,
+    is_multimodal: torch.Tensor,
 ) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@@ -402,63 +402,37 @@ def _merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    flattened = _flatten_embeddings(multimodal_embeddings)
-    try:
-        # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
-        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
-                                      flattened.to(dtype=inputs_embeds.dtype))
-    except RuntimeError as e:
-        num_expected_tokens = is_multimodal.sum().item()
-        assert isinstance(num_expected_tokens, int)
+    if len(multimodal_embeddings) == 0:
+        return inputs_embeds
 
-        if flattened.shape[0] != num_expected_tokens:
+    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
+    input_dtype = inputs_embeds.dtype
+
+    try:
+        # For debugging
+        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
+
+        # NOTE: This can avoid D2H sync (#22105), but fails to
+        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
+                                      mm_embeds_flat.to(dtype=input_dtype))
+    except RuntimeError as e:
+        num_actual_tokens = len(mm_embeds_flat)
+        num_expected_tokens = is_multimodal.sum().item()
+
+        if num_actual_tokens != num_expected_tokens:
             expr = _embedding_count_expression(multimodal_embeddings)
+
             raise ValueError(
-                f"Attempted to assign {expr} = {flattened.shape[0]} "
+                f"Attempted to assign {expr} = {num_actual_tokens} "
                 f"multimodal tokens to {num_expected_tokens} placeholders"
             ) from e
-        else:
-            raise ValueError("Error during masked scatter operation") from e
+
+        raise ValueError("Error during masked scatter operation") from e
 
     return inputs_embeds
 
 
-def embed_multimodal(
-    input_ids: torch.Tensor,
-    multimodal_token_id: int,
-    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
-    multimodal_embeds: NestedTensors,
-) -> torch.Tensor:
-    """
-    Embed token IDs and multimodal inputs and combine their embeddings.
-
-    ``multimodal_token_id`` is used to determine whether a token ID should
-    be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
-
-    Compared to ``merge_multimodal_embeddings``, this avoids running
-    ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
-    which causes issues when the placeholder token ID exceeds the
-    vocabulary size of the language model.
-    """
-    is_multimodal = input_ids == multimodal_token_id
-    is_text = ~is_multimodal
-
-    text_embeds = get_text_embeds(input_ids[is_text])
-    merged_embeds = torch.empty(
-        (input_ids.shape[0], text_embeds.shape[1]),
-        dtype=text_embeds.dtype,
-        device=text_embeds.device,
-    )
-
-    merged_embeds[is_text] = text_embeds
-
-    return _merge_multimodal_embeddings(
-        merged_embeds,
-        is_multimodal,
-        multimodal_embeds,
-    )
-
-
 def merge_multimodal_embeddings(
     input_ids: torch.Tensor,
     inputs_embeds: torch.Tensor,
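Reviewer note, not part of the diff: ``masked_scatter_`` fills the masked positions of ``inputs_embeds`` with the flattened source elements in order, so it matches the boolean-index assignment shown in the "For debugging" comment while staying entirely on device; plain ``inputs_embeds[is_multimodal] = ...`` typically has to materialize the mask's indices (a ``nonzero``-style op), which forces the device-to-host sync that #22105 avoids. The trade-off, as the NOTE says, is that surplus source elements are silently ignored, which is why the ``except`` branch re-checks the counts. A minimal self-contained sketch of the equivalence, using toy sizes on CPU:

import torch

# Toy sizes: 6 token positions, hidden size 4, 3 multimodal tokens.
seq_len, hidden = 6, 4
inputs_embeds = torch.zeros(seq_len, hidden)
is_multimodal = torch.tensor([False, True, True, False, True, False])
mm_embeds_flat = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)

# Reference: boolean-index assignment (the "For debugging" path above).
reference = inputs_embeds.clone()
reference[is_multimodal] = mm_embeds_flat

# The in-place merge used by the diff: the unsqueezed mask broadcasts over
# the hidden dim, and source elements are consumed in mask order.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat)

assert torch.equal(inputs_embeds, reference)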
@@ -491,23 +465,29 @@ def merge_multimodal_embeddings(
         This updates ``inputs_embeds`` in place.
     """
     if isinstance(placeholder_token_id, list):
-        placeholder_token_id = torch.tensor(
-            placeholder_token_id,
-            pin_memory=is_pin_memory_available()).to(device=input_ids.device,
-                                                     non_blocking=True)
-        return _merge_multimodal_embeddings(
-            inputs_embeds,
-            torch.isin(input_ids, placeholder_token_id),
-            multimodal_embeddings,
-        )
+        is_multimodal = isin_list(input_ids, placeholder_token_id)
+    else:
+        is_multimodal = (input_ids == placeholder_token_id)
 
     return _merge_multimodal_embeddings(
         inputs_embeds,
-        (input_ids == placeholder_token_id),
-        multimodal_embeddings,
+        multimodal_embeddings=multimodal_embeddings,
+        is_multimodal=is_multimodal,
     )
 
 
+def isin_list(
+    elements: torch.Tensor,
+    test_elements_list: list[int],
+) -> torch.Tensor:
+    test_elements = torch.tensor(
+        test_elements_list,
+        pin_memory=is_pin_memory_available(),
+    ).to(device=elements.device, non_blocking=True)
+
+    return torch.isin(elements, test_elements)
+
+
 class LayerFn(Protocol):
 
     def __call__(self, prefix: str) -> torch.nn.Module:
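Reviewer note, not part of the diff: the new ``isin_list`` helper builds the lookup tensor in pinned host memory so the ``non_blocking=True`` host-to-device copy can proceed asynchronously on CUDA; on CPU it reduces to a plain ``torch.isin``. A small usage sketch with hypothetical placeholder IDs (illustration only, not values from this commit):

import torch

# Hypothetical placeholder token IDs for illustration only.
IMAGE_TOKEN_ID, VIDEO_TOKEN_ID = 32000, 32001

input_ids = torch.tensor([1, 32000, 32000, 5, 32001, 9])

# Equivalent to isin_list(input_ids, [IMAGE_TOKEN_ID, VIDEO_TOKEN_ID]),
# minus the pinned-memory transfer (plain CPU tensors here).
test_elements = torch.tensor([IMAGE_TOKEN_ID, VIDEO_TOKEN_ID])
is_multimodal = torch.isin(input_ids, test_elements)

print(is_multimodal)  # tensor([False,  True,  True, False,  True, False])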