Increase Flexibility for OOV Multimodal Token Handling (#34858)
Signed-off-by: Alex Brooks <albrooks@redhat.com>
This commit is contained in:
@@ -931,13 +931,11 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None,
|
||||
handle_oov_mm_token: bool,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = super()._embed_text_input_ids(
|
||||
input_ids,
|
||||
embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
# NOTE: inputs_embeds in model runner has size text_config.projection_dim
|
||||
@@ -966,7 +964,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
self._is_text_input = (
|
||||
multimodal_embeddings is None or len(multimodal_embeddings) == 0
|
||||
@@ -980,7 +977,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
@@ -416,7 +416,6 @@ class Eagle2_5_VLForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Embed input IDs with optional multimodal embeddings."""
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
@@ -426,7 +425,6 @@ class Eagle2_5_VLForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -1664,7 +1664,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
@@ -1677,7 +1676,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -975,7 +975,6 @@ class FunASRForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.model.decoder.embed_input_ids(input_ids)
|
||||
|
||||
|
||||
@@ -507,6 +507,11 @@ class Gemma3ForConditionalGeneration(
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.configure_mm_token_handling(
|
||||
vocab_size=config.text_config.vocab_size,
|
||||
mm_token_ids=[config.image_token_index],
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.vision_tower = SiglipVisionModel(
|
||||
config.vision_config,
|
||||
@@ -587,7 +592,6 @@ class Gemma3ForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# Early return for text-only inference (no multimodal data)
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
@@ -598,7 +602,6 @@ class Gemma3ForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -685,7 +685,6 @@ class Gemma3nForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
|
||||
# them here, as the model forward has only access to the input_embeds.
|
||||
@@ -710,7 +709,6 @@ class Gemma3nForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration(
|
||||
self.quant_config = quant_config
|
||||
self.cache_config = cache_config
|
||||
|
||||
# Check for OOV tokens to see if offsets need to be preserved
|
||||
self.configure_mm_token_handling(
|
||||
vocab_size=config.text_config.vocab_size,
|
||||
mm_token_ids=[config.audio_token_index],
|
||||
)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
# The language model is typically a Granite LLM
|
||||
self.language_model = init_vllm_registered_model(
|
||||
@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -130,6 +130,13 @@ class SupportsMultiModal(Protocol):
|
||||
Set internally by `_mark_tower_model`.
|
||||
"""
|
||||
|
||||
_has_oov_mm_tokens: bool = False
|
||||
"""
|
||||
In general, this should be set at init time by invoking
|
||||
`configure_mm_token_handling` in models & passing all potentially
|
||||
OOV multimodal tokens.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
"""
|
||||
@@ -149,6 +156,17 @@ class SupportsMultiModal(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def configure_mm_token_handling(self, vocab_size: int, mm_token_ids: list[int]):
|
||||
"""Check if any multimodal tokens are out of vocabulary. If so, we will
|
||||
explicitly mask all multimodal tokens out when computing text embeddings,
|
||||
since the multimodal embeddings will be scattered over the results.
|
||||
"""
|
||||
self._has_oov_mm_tokens = any(tok_id >= vocab_size for tok_id in mm_token_ids)
|
||||
logger.info(
|
||||
"Contains out of vocabulary multimodal tokens? %s",
|
||||
self._has_oov_mm_tokens,
|
||||
)
|
||||
|
||||
def get_language_model(self) -> VllmModel:
|
||||
"""
|
||||
Returns the underlying language model used for text generation.
|
||||
@@ -324,7 +342,6 @@ class SupportsMultiModal(Protocol):
|
||||
multimodal_embeddings: MultiModalEmbeddings,
|
||||
*,
|
||||
is_multimodal: torch.Tensor,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> Tensor: ...
|
||||
|
||||
def _embed_text_input_ids(
|
||||
@@ -333,17 +350,14 @@ class SupportsMultiModal(Protocol):
|
||||
embed_input_ids: Callable[[Tensor], Tensor],
|
||||
*,
|
||||
is_multimodal: Tensor | None,
|
||||
handle_oov_mm_token: bool,
|
||||
) -> Tensor:
|
||||
if handle_oov_mm_token and is_multimodal is not None:
|
||||
is_text = ~is_multimodal
|
||||
text_embeds = embed_input_ids(input_ids[is_text])
|
||||
|
||||
return torch.empty(
|
||||
(input_ids.shape[0], text_embeds.shape[1]),
|
||||
dtype=text_embeds.dtype,
|
||||
device=text_embeds.device,
|
||||
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
|
||||
if is_multimodal is not None and self._has_oov_mm_tokens:
|
||||
# Force all input IDs to be in vocab; we do this instead of squeezing
|
||||
# to ensure that any external configuration requiring offset tracking,
|
||||
# e.g., LoRA, are applied correctly regardless of whether or not
|
||||
# we have multimodal tokens.
|
||||
in_vocab_ids = input_ids.masked_fill(is_multimodal, 0)
|
||||
return embed_input_ids(in_vocab_ids)
|
||||
|
||||
return embed_input_ids(input_ids)
|
||||
|
||||
@@ -353,7 +367,6 @@ class SupportsMultiModal(Protocol):
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Apply token embeddings to `input_ids`.
|
||||
@@ -361,19 +374,19 @@ class SupportsMultiModal(Protocol):
|
||||
If `multimodal_embeddings` is passed, scatter them into
|
||||
`input_ids` according to the mask `is_multimodal`.
|
||||
|
||||
In case the multi-modal token IDs exceed the vocabulary size of
|
||||
the language model, you can set `handle_oov_mm_token=True`
|
||||
to avoid calling the language model's `embed_input_ids` method
|
||||
on those tokens. Note however that doing so increases memory usage
|
||||
as an additional buffer is needed to hold the input embeddings.
|
||||
NOTE: If this model has multimodal tokens that are out of vocabulary
|
||||
(i.e., self._has_oov_mm_tokens=True), the input_ids will be copied
|
||||
and masked to 0 during the forward pass for the text embeddings.
|
||||
"""
|
||||
from .utils import _merge_multimodal_embeddings
|
||||
|
||||
# Get text embeddings first; multimodal embeddings will clobber
|
||||
# any invalid contents in the indices of multimodal embeddings
|
||||
# for the in vocabulary and out of vocabulary case.
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
|
||||
@@ -764,7 +764,6 @@ class InternS1ForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
@@ -777,7 +776,6 @@ class InternS1ForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -1347,7 +1347,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
@@ -1360,7 +1359,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -544,6 +544,11 @@ class LlavaForConditionalGeneration(
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.configure_mm_token_handling(
|
||||
vocab_size=config.text_config.vocab_size,
|
||||
mm_token_ids=[config.image_token_index],
|
||||
)
|
||||
|
||||
# NOTE: These are special cases for Pixtral-12B in the HF-format
|
||||
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
|
||||
if (
|
||||
|
||||
@@ -270,6 +270,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.configure_mm_token_handling(
|
||||
vocab_size=config.text_config.vocab_size,
|
||||
mm_token_ids=[config.image_token_index],
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.vision_tower = init_vision_tower_for_llava(
|
||||
config,
|
||||
@@ -497,8 +502,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
@@ -508,7 +511,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -2711,13 +2711,11 @@ class Molmo2ForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
|
||||
@@ -628,7 +628,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
@@ -641,7 +640,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -663,13 +663,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.embed_tokens,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
|
||||
@@ -1428,11 +1428,19 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
)
|
||||
|
||||
if len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
# Check for audio-in-video: interleaved video and audio tokens
|
||||
# in the multimodal region. Only use the interleaved path when
|
||||
# needed; otherwise fall back to the default parent implementation.
|
||||
@@ -1450,7 +1458,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
input_ids,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
return merge_interleaved_embeddings(
|
||||
inputs_embeds,
|
||||
@@ -1467,7 +1474,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -672,13 +672,11 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.language_model.embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
|
||||
@@ -380,13 +380,11 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.model.embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
|
||||
@@ -389,13 +389,11 @@ class Qwen3ASRForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.language_model.embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
|
||||
@@ -1851,13 +1851,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.language_model.embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
@@ -1962,7 +1960,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -2301,13 +2301,11 @@ class Qwen3VLForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.language_model.embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
|
||||
@@ -1184,13 +1184,11 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None,
|
||||
handle_oov_mm_token: bool,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = super()._embed_text_input_ids(
|
||||
input_ids,
|
||||
embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
# NOTE: inputs_embeds in model runner has size text_config.projection_size
|
||||
@@ -1219,7 +1217,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
self._is_text_input = (
|
||||
multimodal_embeddings is None or len(multimodal_embeddings) == 0
|
||||
@@ -1232,7 +1229,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
@@ -877,7 +877,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
@@ -890,7 +889,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -937,7 +937,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
@@ -945,6 +944,19 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
|
||||
# NOTE: This behavior is consistent with the previous OOV handling,
|
||||
# but does not currently handle the start/stop toks around the
|
||||
# image features (<patch_start> <patch_end> <im_start> <im_end>)
|
||||
# See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
|
||||
#
|
||||
# If this becomes an issue or we refactor to handle this using the
|
||||
# processor info in the future, it would probably be best to handle
|
||||
# those too.
|
||||
self.configure_mm_token_handling(
|
||||
self.config.text_config.vocab_size,
|
||||
[self.config.image_token_id],
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.vision_model = Step3VisionTransformer(
|
||||
config.vision_config,
|
||||
@@ -1080,8 +1092,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
@@ -1091,7 +1101,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -265,7 +265,6 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# We do not really use any input tokens and therefore no embeddings
|
||||
# to be calculated. However, due to the mandatory token ids in
|
||||
|
||||
@@ -551,6 +551,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
self.multi_modal_config = multimodal_config
|
||||
assert self.multi_modal_config
|
||||
|
||||
self.configure_mm_token_handling(
|
||||
self.config.vocab_size,
|
||||
[self.config.audio_token_index],
|
||||
)
|
||||
|
||||
self.secondary_weights = []
|
||||
if config.audio_model_id is not None:
|
||||
# this prefix is not for initialization, but for loading weights
|
||||
@@ -707,8 +712,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
@@ -718,7 +721,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -298,7 +298,6 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Pass post-conv embeddings directly as input.
|
||||
|
||||
|
||||
@@ -996,7 +996,6 @@ class WhisperForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# This method just returns the decoder sequence embeddings since
|
||||
# Whisper does not have encoder text tokens.
|
||||
|
||||
Reference in New Issue
Block a user