[Core] Simplify multimodal masking (#34246)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
@@ -362,7 +362,9 @@ class SupportsMultiModal(Protocol):
|
||||
# to ensure that any external configuration requiring offset tracking,
|
||||
# e.g., LoRA, is applied correctly regardless of whether or not
|
||||
# we have multimodal tokens.
|
||||
in_vocab_ids = input_ids.masked_fill(is_multimodal, 0)
|
||||
in_vocab_ids = input_ids.masked_fill(
|
||||
is_multimodal.to(device=input_ids.device, non_blocking=True), 0
|
||||
)
|
||||
return embed_input_ids(in_vocab_ids)
|
||||
|
||||
return embed_input_ids(input_ids)
|
||||
|
||||
@@ -1215,7 +1215,6 @@ class NemotronH_Nano_VL_V2(
|
||||
These embeddings will replace the placeholder embeddings to create
|
||||
input_embeds for the LLM.
|
||||
"""
|
||||
device = video_embeddings.device
|
||||
tokenizer = cached_tokenizer_from_config(self.model_config)
|
||||
|
||||
# Generate video replacement token IDs using get_video_repl
|
||||
@@ -1234,10 +1233,10 @@ class NemotronH_Nano_VL_V2(
|
||||
)
|
||||
|
||||
# video_repl.full is a list of token IDs
|
||||
repl_token_ids = torch.tensor(video_repl.full, device=device)
|
||||
repl_token_ids = torch.tensor(video_repl.full)
|
||||
|
||||
# Get embedding token IDs for image context (use pre-tokenized version)
|
||||
embed_token_ids = torch.tensor(self._img_context_token_ids, device=device)
|
||||
embed_token_ids = torch.tensor(self._img_context_token_ids)
|
||||
|
||||
# Create mask for video embedding positions
|
||||
is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
|
||||
|
||||
@@ -211,15 +211,12 @@ def merge_interleaved_embeddings(
|
||||
|
||||
# Scatter each modality to its positions
|
||||
if video_embeds:
|
||||
video_positions = is_video.nonzero(as_tuple=True)[0]
|
||||
inputs_embeds[video_positions] = torch.cat(video_embeds, dim=0)
|
||||
inputs_embeds[is_video] = torch.cat(video_embeds, dim=0)
|
||||
if audio_embeds:
|
||||
audio_positions = is_audio.nonzero(as_tuple=True)[0]
|
||||
inputs_embeds[audio_positions] = torch.cat(audio_embeds, dim=0)
|
||||
inputs_embeds[is_audio] = torch.cat(audio_embeds, dim=0)
|
||||
if other_embeds:
|
||||
other_mask = is_multimodal & ~is_video & ~is_audio
|
||||
other_positions = other_mask.nonzero(as_tuple=True)[0]
|
||||
inputs_embeds[other_positions] = torch.cat(other_embeds, dim=0)
|
||||
inputs_embeds[other_mask] = torch.cat(other_embeds, dim=0)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
@@ -1457,8 +1454,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
video_token_id = self.config.video_token_index
|
||||
audio_token_id = self.config.audio_token_index
|
||||
|
||||
is_video = is_multimodal & (input_ids == video_token_id)
|
||||
is_audio = is_multimodal & (input_ids == audio_token_id)
|
||||
input_ids_cpu = input_ids.cpu()
|
||||
is_video = is_multimodal & (input_ids_cpu == video_token_id)
|
||||
is_audio = is_multimodal & (input_ids_cpu == audio_token_id)
|
||||
|
||||
num_video = is_video.sum().item()
|
||||
num_audio = is_audio.sum().item()
|
||||
|
||||
@@ -1869,8 +1869,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
# both the deepstack path and the final embedding merge.
|
||||
video_token_id = self.config.video_token_id
|
||||
audio_token_id = self.config.audio_token_id
|
||||
is_video = is_multimodal & (input_ids == video_token_id)
|
||||
is_audio = is_multimodal & (input_ids == audio_token_id)
|
||||
input_ids_cpu = input_ids.cpu()
|
||||
is_video = is_multimodal & (input_ids_cpu == video_token_id)
|
||||
is_audio = is_multimodal & (input_ids_cpu == audio_token_id)
|
||||
num_video = is_video.sum().item()
|
||||
num_audio = is_audio.sum().item()
|
||||
|
||||
|
||||
@@ -1977,7 +1977,6 @@ class Qwen3VLForConditionalGeneration(
|
||||
These embeddings will replace the placeholder embeddings to create
|
||||
input_embeds for the LLM.
|
||||
"""
|
||||
device = video_embeddings.device
|
||||
|
||||
# Generate video replacement token IDs using get_video_repl
|
||||
# This tokenizes each frame separator independently, then uses pre-tokenized
|
||||
@@ -1993,8 +1992,10 @@ class Qwen3VLForConditionalGeneration(
|
||||
select_token_id=self.is_multimodal_pruning_enabled,
|
||||
)
|
||||
|
||||
repl_token_ids = torch.tensor(video_repl.full, device=device)
|
||||
embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
|
||||
repl_token_ids = torch.tensor(video_repl.full)
|
||||
embed_token_id = _cached_tensor(
|
||||
self.config.video_token_id, repl_token_ids.device
|
||||
)
|
||||
is_video_embed = torch.isin(repl_token_ids, embed_token_id)
|
||||
|
||||
# Get text embeddings for indicator tokens (has only `visual_dim`).
|
||||
|
||||
@@ -468,14 +468,8 @@ def _merge_multimodal_embeddings(
|
||||
input_dtype = inputs_embeds.dtype
|
||||
|
||||
try:
|
||||
# For debugging
|
||||
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
|
||||
|
||||
# NOTE: This can avoid D2H sync (#22105), but fails to
|
||||
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
|
||||
inputs_embeds.masked_scatter_(
|
||||
is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
|
||||
)
|
||||
# If is_multimodal is on CPU this avoids a D2H sync
|
||||
inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
|
||||
except RuntimeError as e:
|
||||
num_actual_tokens = len(mm_embeds_flat)
|
||||
num_expected_tokens = is_multimodal.sum().item()
|
||||
@@ -488,7 +482,7 @@ def _merge_multimodal_embeddings(
|
||||
f"multimodal tokens to {num_expected_tokens} placeholders"
|
||||
) from e
|
||||
|
||||
raise ValueError("Error during masked scatter operation") from e
|
||||
raise ValueError("Error during index put operation") from e
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
Reference in New Issue
Block a user