Increase Flexibility for OOV Multimodal Token Handling (#34858)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
This commit is contained in:
Alex Brooks
2026-03-08 21:30:49 -06:00
committed by GitHub
parent 90512b2e8b
commit bd2659a566
28 changed files with 79 additions and 77 deletions

View File

@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration(
self.quant_config = quant_config
self.cache_config = cache_config
# Check for OOV tokens to see if offsets need to be preserved
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.audio_token_index],
)
with self._mark_language_model(vllm_config):
# The language model is typically a Granite LLM
self.language_model = init_vllm_registered_model(
@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(