[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Authored by Cyrus Leung on 2025-09-27 16:15:12 +08:00, committed by GitHub
parent 176173989a
commit 27d7638b94
80 changed files with 966 additions and 1139 deletions


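For context on the title: the merge path now consumes a precomputed boolean `is_multimodal` index mask instead of re-deriving placeholder positions from model-specific token IDs at merge time. A minimal standalone sketch of the two approaches on toy tensors (the placeholder ID 32000 is made up for illustration, not taken from vLLM):

    import torch

    hidden_size = 4
    input_ids = torch.tensor([101, 32000, 32000, 102])
    inputs_embeds = torch.zeros(4, hidden_size)
    image_embeds = torch.ones(2, hidden_size)

    # Old approach: recompute the placeholder mask from token IDs at merge time.
    mask_from_ids = torch.isin(input_ids, torch.tensor([32000]))

    # New approach: the caller passes the placeholder positions directly,
    # so the merge no longer needs to know any placeholder token ID.
    is_multimodal = torch.tensor([False, True, True, False])

    inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), image_embeds)
    assert torch.equal(inputs_embeds[is_multimodal], image_embeds)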
@@ -4,7 +4,7 @@
 import itertools
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
+from typing import Any, Literal, Optional, Protocol, Union, overload

 import torch
 import torch.nn as nn
@@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
 def _merge_multimodal_embeddings(
     inputs_embeds: torch.Tensor,
-    is_multimodal: torch.Tensor,
     multimodal_embeddings: NestedTensors,
+    is_multimodal: torch.Tensor,
 ) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@@ -402,63 +402,37 @@ def _merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    flattened = _flatten_embeddings(multimodal_embeddings)
-    try:
-        # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
-        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
-                                      flattened.to(dtype=inputs_embeds.dtype))
-    except RuntimeError as e:
-        num_expected_tokens = is_multimodal.sum().item()
-        assert isinstance(num_expected_tokens, int)
+    if len(multimodal_embeddings) == 0:
+        return inputs_embeds

-        if flattened.shape[0] != num_expected_tokens:
+    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
+    input_dtype = inputs_embeds.dtype
+
+    try:
+        # For debugging
+        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
+
+        # NOTE: This can avoid D2H sync (#22105), but fails to
+        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
+                                      mm_embeds_flat.to(dtype=input_dtype))
+    except RuntimeError as e:
+        num_actual_tokens = len(mm_embeds_flat)
+        num_expected_tokens = is_multimodal.sum().item()
+
+        if num_actual_tokens != num_expected_tokens:
             expr = _embedding_count_expression(multimodal_embeddings)
             raise ValueError(
-                f"Attempted to assign {expr} = {flattened.shape[0]} "
+                f"Attempted to assign {expr} = {num_actual_tokens} "
                 f"multimodal tokens to {num_expected_tokens} placeholders"
             ) from e
         else:
-            raise ValueError("Error during masked scatter operation") from e
+            raise ValueError("Error during masked scatter operation") from e

     return inputs_embeds
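The in-place ``masked_scatter_`` above is what keeps the merge on-device; the commented-out indexed assignment is the readable equivalent but needs the mask's true-count on the host. A rough standalone sketch of the trade-off the NOTE describes, on toy CPU tensors (illustrative only, not vLLM's actual call path):

    import torch

    inputs_embeds = torch.zeros(4, 2)
    is_multimodal = torch.tensor([False, True, False, False])  # one placeholder slot
    mm_embeds_flat = torch.ones(2, 2)                          # two embeddings: mismatch

    # Boolean-index assignment validates the count (and, on CUDA, forces a
    # device-to-host sync to do so), so the mismatch surfaces as an error:
    try:
        inputs_embeds.clone()[is_multimodal] = mm_embeds_flat
    except RuntimeError as e:
        print("indexed assignment raised:", e)

    # masked_scatter_ avoids that sync but silently ignores the surplus row,
    # which is exactly the failure mode the NOTE warns about:
    inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat)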
-def embed_multimodal(
-    input_ids: torch.Tensor,
-    multimodal_token_id: int,
-    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
-    multimodal_embeds: NestedTensors,
-) -> torch.Tensor:
-    """
-    Embed token IDs and multimodal inputs and combine their embeddings.
-
-    ``multimodal_token_id`` is used to determine whether a token ID should
-    be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
-
-    Compared to ``merge_multimodal_embeddings`, this avoids running
-    ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
-    which causes issues when the placeholder token ID exceeds the
-    vocabulary size of the language model.
-    """
-    is_multimodal = input_ids == multimodal_token_id
-    is_text = ~is_multimodal
-
-    text_embeds = get_text_embeds(input_ids[is_text])
-    merged_embeds = torch.empty(
-        (input_ids.shape[0], text_embeds.shape[1]),
-        dtype=text_embeds.dtype,
-        device=text_embeds.device,
-    )
-
-    merged_embeds[is_text] = text_embeds
-
-    return _merge_multimodal_embeddings(
-        merged_embeds,
-        is_multimodal,
-        multimodal_embeds,
-    )
 def merge_multimodal_embeddings(
     input_ids: torch.Tensor,
     inputs_embeds: torch.Tensor,
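The removed ``embed_multimodal`` existed because some placeholder token IDs lie outside the language model's vocabulary, so embedding ``input_ids`` wholesale can index out of range; embedding only the text positions and scattering multimodal embeddings into the rest avoids that, and the mask-based merge keeps the same property. A small self-contained illustration (the vocab size and placeholder ID below are invented):

    import torch
    import torch.nn as nn

    vocab_size, hidden_size = 100, 8
    embed_tokens = nn.Embedding(vocab_size, hidden_size)

    placeholder_id = 32000  # out-of-vocabulary placeholder, purely illustrative
    input_ids = torch.tensor([1, placeholder_id, placeholder_id, 2])

    with torch.inference_mode():
        try:
            embed_tokens(input_ids)  # fails: placeholder ID exceeds vocab_size
        except IndexError as e:
            print("direct lookup failed:", e)

        # Embed only the text positions; the placeholder rows are left to be
        # overwritten by the multimodal embeddings (what embed_multimodal did).
        is_multimodal = input_ids == placeholder_id
        inputs_embeds = torch.zeros(input_ids.shape[0], hidden_size)
        inputs_embeds[~is_multimodal] = embed_tokens(input_ids[~is_multimodal])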
@@ -491,23 +465,29 @@ def merge_multimodal_embeddings(
         This updates ``inputs_embeds`` in place.
     """
     if isinstance(placeholder_token_id, list):
-        placeholder_token_id = torch.tensor(
-            placeholder_token_id,
-            pin_memory=is_pin_memory_available()).to(device=input_ids.device,
-                                                     non_blocking=True)
-        return _merge_multimodal_embeddings(
-            inputs_embeds,
-            torch.isin(input_ids, placeholder_token_id),
-            multimodal_embeddings,
-        )
+        is_multimodal = isin_list(input_ids, placeholder_token_id)
+    else:
+        is_multimodal = (input_ids == placeholder_token_id)

     return _merge_multimodal_embeddings(
         inputs_embeds,
-        (input_ids == placeholder_token_id),
-        multimodal_embeddings,
+        multimodal_embeddings=multimodal_embeddings,
+        is_multimodal=is_multimodal,
     )


+def isin_list(
+    elements: torch.Tensor,
+    test_elements_list: list[int],
+) -> torch.Tensor:
+    test_elements = torch.tensor(
+        test_elements_list,
+        pin_memory=is_pin_memory_available(),
+    ).to(device=elements.device, non_blocking=True)
+
+    return torch.isin(elements, test_elements)
 class LayerFn(Protocol):

     def __call__(self, prefix: str) -> torch.nn.Module:
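For reference, a usage sketch of the new ``isin_list`` helper added above: it centralizes the list-of-placeholder-IDs case, and pinning the host tensor lets the host-to-device copy run non-blocking. The standalone version below drops the pinned-memory detail, and the token IDs are made up:

    import torch

    def isin_list(elements: torch.Tensor, test_elements_list: list[int]) -> torch.Tensor:
        # Same idea as the helper above, minus the pinned-memory fast path.
        test_elements = torch.tensor(test_elements_list, device=elements.device)
        return torch.isin(elements, test_elements)

    input_ids = torch.tensor([1, 32000, 32001, 2, 32000])   # invented placeholder IDs
    is_multimodal = isin_list(input_ids, [32000, 32001])
    print(is_multimodal)  # tensor([False,  True,  True, False,  True])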