[Core] Simplify multimodal masking (#34246)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger
2026-04-01 09:18:22 +01:00
committed by GitHub
parent 36d7f19897
commit 4f6eed3bd4
9 changed files with 54 additions and 51 deletions

View File

@@ -362,7 +362,9 @@ class SupportsMultiModal(Protocol):
# to ensure that any external configuration requiring offset tracking,
# e.g., LoRA, are applied correctly regardless of whether or not
# we have multimodal tokens.
in_vocab_ids = input_ids.masked_fill(is_multimodal, 0)
in_vocab_ids = input_ids.masked_fill(
is_multimodal.to(device=input_ids.device, non_blocking=True), 0
)
return embed_input_ids(in_vocab_ids)
return embed_input_ids(input_ids)

View File

@@ -1215,7 +1215,6 @@ class NemotronH_Nano_VL_V2(
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
"""
device = video_embeddings.device
tokenizer = cached_tokenizer_from_config(self.model_config)
# Generate video replacement token IDs using get_video_repl
@@ -1234,10 +1233,10 @@ class NemotronH_Nano_VL_V2(
)
# video_repl.full is a list of token IDs
repl_token_ids = torch.tensor(video_repl.full, device=device)
repl_token_ids = torch.tensor(video_repl.full)
# Get embedding token IDs for image context (use pre-tokenized version)
embed_token_ids = torch.tensor(self._img_context_token_ids, device=device)
embed_token_ids = torch.tensor(self._img_context_token_ids)
# Create mask for video embedding positions
is_video_embed = torch.isin(repl_token_ids, embed_token_ids)

View File

@@ -211,15 +211,12 @@ def merge_interleaved_embeddings(
# Scatter each modality to its positions
if video_embeds:
video_positions = is_video.nonzero(as_tuple=True)[0]
inputs_embeds[video_positions] = torch.cat(video_embeds, dim=0)
inputs_embeds[is_video] = torch.cat(video_embeds, dim=0)
if audio_embeds:
audio_positions = is_audio.nonzero(as_tuple=True)[0]
inputs_embeds[audio_positions] = torch.cat(audio_embeds, dim=0)
inputs_embeds[is_audio] = torch.cat(audio_embeds, dim=0)
if other_embeds:
other_mask = is_multimodal & ~is_video & ~is_audio
other_positions = other_mask.nonzero(as_tuple=True)[0]
inputs_embeds[other_positions] = torch.cat(other_embeds, dim=0)
inputs_embeds[other_mask] = torch.cat(other_embeds, dim=0)
return inputs_embeds
@@ -1457,8 +1454,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
video_token_id = self.config.video_token_index
audio_token_id = self.config.audio_token_index
is_video = is_multimodal & (input_ids == video_token_id)
is_audio = is_multimodal & (input_ids == audio_token_id)
input_ids_cpu = input_ids.cpu()
is_video = is_multimodal & (input_ids_cpu == video_token_id)
is_audio = is_multimodal & (input_ids_cpu == audio_token_id)
num_video = is_video.sum().item()
num_audio = is_audio.sum().item()

View File

@@ -1869,8 +1869,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
# both the deepstack path and the final embedding merge.
video_token_id = self.config.video_token_id
audio_token_id = self.config.audio_token_id
is_video = is_multimodal & (input_ids == video_token_id)
is_audio = is_multimodal & (input_ids == audio_token_id)
input_ids_cpu = input_ids.cpu()
is_video = is_multimodal & (input_ids_cpu == video_token_id)
is_audio = is_multimodal & (input_ids_cpu == audio_token_id)
num_video = is_video.sum().item()
num_audio = is_audio.sum().item()

View File

@@ -1977,7 +1977,6 @@ class Qwen3VLForConditionalGeneration(
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
"""
device = video_embeddings.device
# Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized
@@ -1993,8 +1992,10 @@ class Qwen3VLForConditionalGeneration(
select_token_id=self.is_multimodal_pruning_enabled,
)
repl_token_ids = torch.tensor(video_repl.full, device=device)
embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
repl_token_ids = torch.tensor(video_repl.full)
embed_token_id = _cached_tensor(
self.config.video_token_id, repl_token_ids.device
)
is_video_embed = torch.isin(repl_token_ids, embed_token_id)
    # Get text embeddings for indicator tokens (has only `visual_dim`).

View File

@@ -468,14 +468,8 @@ def _merge_multimodal_embeddings(
input_dtype = inputs_embeds.dtype
try:
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(
is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
)
# If is_multimodal is on CPU this avoids a D2H sync
inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
@@ -488,7 +482,7 @@ def _merge_multimodal_embeddings(
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
raise ValueError("Error during masked scatter operation") from e
raise ValueError("Error during index put operation") from e
return inputs_embeds