Increase Flexibility for OOV Multimodal Token Handling (#34858)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
Alex Brooks
2026-03-08 21:30:49 -06:00
committed by GitHub
parent 90512b2e8b
commit bd2659a566
28 changed files with 79 additions and 77 deletions

View File

@@ -931,13 +931,11 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids(
input_ids,
embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
# NOTE: inputs_embeds in model runner has size text_config.projection_dim
@@ -966,7 +964,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
self._is_text_input = (
multimodal_embeddings is None or len(multimodal_embeddings) == 0
@@ -980,7 +977,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:

View File

@@ -416,7 +416,6 @@ class Eagle2_5_VLForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
"""Embed input IDs with optional multimodal embeddings."""
if multimodal_embeddings is None or is_multimodal is None:
@@ -426,7 +425,6 @@ class Eagle2_5_VLForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -1664,7 +1664,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
@@ -1677,7 +1676,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -975,7 +975,6 @@ class FunASRForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.model.decoder.embed_input_ids(input_ids)

View File

@@ -507,6 +507,11 @@ class Gemma3ForConditionalGeneration(
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel(
config.vision_config,
@@ -587,7 +592,6 @@ class Gemma3ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# Early return for text-only inference (no multimodal data)
if multimodal_embeddings is None or is_multimodal is None:
@@ -598,7 +602,6 @@ class Gemma3ForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -685,7 +685,6 @@ class Gemma3nForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
@@ -710,7 +709,6 @@ class Gemma3nForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration(
self.quant_config = quant_config
self.cache_config = cache_config
# Check for OOV tokens to see if offsets need to be preserved
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.audio_token_index],
)
with self._mark_language_model(vllm_config):
# The language model is typically a Granite LLM
self.language_model = init_vllm_registered_model(
@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -130,6 +130,13 @@ class SupportsMultiModal(Protocol):
Set internally by `_mark_tower_model`.
"""
_has_oov_mm_tokens: bool = False
"""
In general, this should be set at init time by invoking
`configure_mm_token_handling` and passing all potentially
OOV multimodal token IDs.
"""
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""
@@ -149,6 +156,17 @@ class SupportsMultiModal(Protocol):
"""
...
def configure_mm_token_handling(self, vocab_size: int, mm_token_ids: list[int]):
"""Check if any multimodal tokens are out of vocabulary. If so, we will
explicitly mask all multimodal tokens out when computing text embeddings,
since the multimodal embeddings will be scattered over the results.
"""
self._has_oov_mm_tokens = any(tok_id >= vocab_size for tok_id in mm_token_ids)
logger.info(
"Contains out of vocabulary multimodal tokens? %s",
self._has_oov_mm_tokens,
)
def get_language_model(self) -> VllmModel:
"""
Returns the underlying language model used for text generation.
@@ -324,7 +342,6 @@ class SupportsMultiModal(Protocol):
multimodal_embeddings: MultiModalEmbeddings,
*,
is_multimodal: torch.Tensor,
handle_oov_mm_token: bool = False,
) -> Tensor: ...
def _embed_text_input_ids(
@@ -333,17 +350,14 @@ class SupportsMultiModal(Protocol):
embed_input_ids: Callable[[Tensor], Tensor],
*,
is_multimodal: Tensor | None,
handle_oov_mm_token: bool,
) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = embed_input_ids(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
if is_multimodal is not None and self._has_oov_mm_tokens:
# Force all input IDs to be in vocab; we do this instead of squeezing
# out the multimodal positions so that any external configurations
# requiring offset tracking, e.g., LoRA, are applied correctly
# regardless of whether we have multimodal tokens.
in_vocab_ids = input_ids.masked_fill(is_multimodal, 0)
return embed_input_ids(in_vocab_ids)
return embed_input_ids(input_ids)
@@ -353,7 +367,6 @@ class SupportsMultiModal(Protocol):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> Tensor:
"""
Apply token embeddings to `input_ids`.
@@ -361,19 +374,19 @@ class SupportsMultiModal(Protocol):
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=True`
to avoid calling the language model's `embed_input_ids` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
NOTE: If this model has multimodal tokens that are out of vocabulary
(i.e., self._has_oov_mm_tokens=True), the input_ids will be copied
and masked to 0 during the forward pass when computing the text embeddings.
"""
from .utils import _merge_multimodal_embeddings
# Get text embeddings first; the multimodal embeddings will then clobber
# any placeholder contents at the multimodal token positions, in both
# the in-vocabulary and out-of-vocabulary cases.
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
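
To make the new behavior concrete, here is a minimal self-contained sketch (plain PyTorch with toy sizes, not vLLM code) of the path `_embed_text_input_ids` takes when `_has_oov_mm_tokens` is True: out-of-vocabulary placeholder IDs are masked to 0 before the embedding lookup, and the multimodal embeddings are then scattered over those positions. In vLLM the scatter itself is delegated to `_merge_multimodal_embeddings`; the `masked_scatter` below stands in for it.

# Illustrative sketch only; toy dimensions and values are assumptions.
import torch
import torch.nn as nn

vocab_size, hidden = 8, 4
embed = nn.Embedding(vocab_size, hidden)

# Token ID 100 stands in for an OOV multimodal placeholder (>= vocab_size).
input_ids = torch.tensor([1, 2, 100, 100, 3])
is_multimodal = input_ids >= vocab_size
mm_embeds = torch.randn(int(is_multimodal.sum()), hidden)

# 1) Force all IDs into vocab so the embedding lookup never sees an OOV index.
in_vocab_ids = input_ids.masked_fill(is_multimodal, 0)
inputs_embeds = embed(in_vocab_ids)

# 2) Scatter the multimodal embeddings over the placeholder positions,
#    replacing whatever the meaningless ID-0 lookups produced there.
inputs_embeds = inputs_embeds.masked_scatter(is_multimodal.unsqueeze(-1), mm_embeds)
print(inputs_embeds.shape)  # torch.Size([5, 4])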

View File

@@ -764,7 +764,6 @@ class InternS1ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
@@ -777,7 +776,6 @@ class InternS1ForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -1347,7 +1347,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
@@ -1360,7 +1359,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -544,6 +544,11 @@ class LlavaForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
# NOTE: These are special cases for Pixtral-12B in the HF-format
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
if (

View File

@@ -270,6 +270,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.config = config
self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
@@ -497,8 +502,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
@@ -508,7 +511,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -2711,13 +2711,11 @@ class Molmo2ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:

View File

@@ -628,7 +628,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
@@ -641,7 +640,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -663,13 +663,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.embed_tokens,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:

View File

@@ -1428,11 +1428,19 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
)
if len(multimodal_embeddings) == 0:
return inputs_embeds
# Check for audio-in-video: interleaved video and audio tokens
# in the multimodal region. Only use the interleaved path when
# needed; otherwise fall back to the default parent implementation.
@@ -1450,7 +1458,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return merge_interleaved_embeddings(
inputs_embeds,
@@ -1467,7 +1474,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -672,13 +672,11 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:

View File

@@ -380,13 +380,11 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:

View File

@@ -389,13 +389,11 @@ class Qwen3ASRForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:

View File

@@ -1851,13 +1851,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
@@ -1962,7 +1960,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -2301,13 +2301,11 @@ class Qwen3VLForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:

View File

@@ -1184,13 +1184,11 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids(
input_ids,
embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
# NOTE: inputs_embeds in model runner has size text_config.projection_size
@@ -1219,7 +1217,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
self._is_text_input = (
multimodal_embeddings is None or len(multimodal_embeddings) == 0
@@ -1232,7 +1229,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:

View File

@@ -877,7 +877,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
@@ -890,7 +889,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -937,7 +937,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
@@ -945,6 +944,19 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
# NOTE: This behavior is consistent with the previous OOV handling,
# but does not currently handle the start/stop tokens around the
# image features (<patch_start> <patch_end> <im_start> <im_end>)
# See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
#
# If this becomes an issue or we refactor to handle this using the
# processor info in the future, it would probably be best to handle
# those too.
self.configure_mm_token_handling(
self.config.text_config.vocab_size,
[self.config.image_token_id],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = Step3VisionTransformer(
config.vision_config,
@@ -1080,8 +1092,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
@@ -1091,7 +1101,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
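
If the surrounding start/stop tokens ever need the same treatment, the registration above would presumably just grow its ID list. A hedged sketch of what that might look like follows; the `patch_start_token_id`/`patch_end_token_id` attribute names are assumptions for illustration, not actual Step3 config fields.

# Hypothetical extension, not part of this commit; attribute names are assumed.
self.configure_mm_token_handling(
    self.config.text_config.vocab_size,
    [
        self.config.image_token_id,
        self.config.patch_start_token_id,  # hypothetical name
        self.config.patch_end_token_id,    # hypothetical name
    ],
)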

View File

@@ -265,7 +265,6 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# We do not really use any input tokens, and therefore no embeddings
# need to be calculated. However, due to the mandatory token ids in

View File

@@ -551,6 +551,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self.multi_modal_config = multimodal_config
assert self.multi_modal_config
self.configure_mm_token_handling(
self.config.vocab_size,
[self.config.audio_token_index],
)
self.secondary_weights = []
if config.audio_model_id is not None:
# this prefix is not for initialization, but for loading weights
@@ -707,8 +712,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
@@ -718,7 +721,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -298,7 +298,6 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
"""Pass post-conv embeddings directly as input.

View File

@@ -996,7 +996,6 @@ class WhisperForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens.