Increase Flexibility for OOV Multimodal Token Handling (#34858)
Signed-off-by: Alex Brooks <albrooks@redhat.com>
This commit is contained in:
@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration(
|
||||
self.quant_config = quant_config
|
||||
self.cache_config = cache_config
|
||||
|
||||
# Check for OOV tokens to see if offsets need to be preserved
|
||||
self.configure_mm_token_handling(
|
||||
vocab_size=config.text_config.vocab_size,
|
||||
mm_token_ids=[config.audio_token_index],
|
||||
)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
# The language model is typically a Granite LLM
|
||||
self.language_model = init_vllm_registered_model(
|
||||
@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration(
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user