Truncation control for embedding models (#14776)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
Author: Gabriel Marinho
Authored: 2025-04-29 22:24:57 -03:00, committed by GitHub
Parent: 4055130a85
Commit: 1c2bc7ead0
21 changed files with 333 additions and 71 deletions
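The diff below threads a per-request tokenization_kwargs dict from preprocess() down to the tokenizer, so truncation can be controlled per embedding request instead of relying on a single global tokenizer default. As a rough, standalone illustration of what such a dict controls once it reaches a Hugging Face tokenizer (the checkpoint name and the values are placeholders, not part of this commit):

    from transformers import AutoTokenizer

    # Placeholder checkpoint: any model with an HF tokenizer behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

    prompt = "a very long document " * 200

    # The kind of per-request kwargs this change forwards into tokenizer.encode():
    tokenization_kwargs = {"truncation": True, "max_length": 128}

    token_ids = tokenizer.encode(prompt, **tokenization_kwargs)
    assert len(token_ids) <= 128  # the embedding input now fits the requested limit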


@@ -2,7 +2,7 @@
 import asyncio
 from collections.abc import Mapping
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast

 from typing_extensions import assert_never
@@ -183,18 +183,21 @@ class InputPreprocessor:
         self,
         prompt: str,
         lora_request: Optional[LoRARequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[int]:
         """
         Apply the model's tokenizer to a text prompt, returning the
         corresponding token IDs.
         """
         tokenizer = self.get_tokenizer_group()
-        add_special_tokens = None
+        if tokenization_kwargs is None:
+            tokenization_kwargs = {}
         if self.model_config.hf_config.model_type == "whisper":
             # For Whisper, special tokens should be provided by the user based
             # on the task and language of their request. Also needed to avoid
             # appending an EOS token to the prompt which disrupts generation.
-            add_special_tokens = False
+            tokenization_kwargs["add_special_tokens"] = False
         if (self.model_config.encoder_config is not None
                 and self.model_config.encoder_config.get(
@@ -203,25 +206,27 @@ class InputPreprocessor:
         return tokenizer.encode(prompt=prompt,
                                 lora_request=lora_request,
-                                add_special_tokens=add_special_tokens)
+                                **tokenization_kwargs)

     async def _tokenize_prompt_async(
         self,
         prompt: str,
         lora_request: Optional[LoRARequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[int]:
         """Async version of :meth:`_tokenize_prompt`."""
         tokenizer = self.get_tokenizer_group()
-        add_special_tokens = None
+        if tokenization_kwargs is None:
+            tokenization_kwargs = {}
         if self.model_config.hf_config.model_type == "whisper":
             # For Whisper, special tokens should be provided by the user based
             # on the task and language of their request. Also needed to avoid
             # appending an EOS token to the prompt which disrupts generation.
-            add_special_tokens = False
-        return await tokenizer.encode_async(
-            prompt=prompt,
-            lora_request=lora_request,
-            add_special_tokens=add_special_tokens)
+            tokenization_kwargs["add_special_tokens"] = False
+        return await tokenizer.encode_async(prompt=prompt,
+                                            lora_request=lora_request,
+                                            **tokenization_kwargs)

     def _process_multimodal(
         self,
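The two hunks above replace the single add_special_tokens flag with a kwargs dict that is splatted into tokenizer.encode / encode_async. A standalone sketch of that merging pattern (the helper name is invented for illustration, and it copies the dict rather than mutating the caller's as the real code does):

    from typing import Any, Optional

    def build_tokenization_kwargs(
            model_type: str,
            tokenization_kwargs: Optional[dict[str, Any]] = None) -> dict[str, Any]:
        """Sketch only, not vLLM code: start from the caller's per-request kwargs,
        layer model-specific overrides on top, and return the dict that would be
        splatted into tokenizer.encode()."""
        kwargs = dict(tokenization_kwargs or {})
        if model_type == "whisper":
            # Whisper: the user supplies special tokens themselves, and an
            # appended EOS token would disrupt generation.
            kwargs["add_special_tokens"] = False
        return kwargs

    # A truncating embedding-style request against a Whisper-type model:
    print(build_tokenization_kwargs("whisper", {"truncation": True, "max_length": 128}))
    # {'truncation': True, 'max_length': 128, 'add_special_tokens': False}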
@@ -281,6 +286,7 @@ class InputPreprocessor:
     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -304,6 +310,7 @@ class InputPreprocessor:
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 lora_request=lora_request,
+                tokenization_kwargs=tokenization_kwargs,
             )

             return token_inputs(
@@ -352,6 +359,7 @@ class InputPreprocessor:
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 lora_request=lora_request,
+                tokenization_kwargs=tokenization_kwargs,
             )

             return token_inputs(
@@ -364,6 +372,7 @@ class InputPreprocessor:
     async def _prompt_to_llm_inputs_async(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -375,6 +384,7 @@ class InputPreprocessor:
             prompt_token_ids = await self._tokenize_prompt_async(
                 prompt_text,
                 lora_request=lora_request,
+                tokenization_kwargs=tokenization_kwargs,
             )

             return token_inputs(
@@ -517,6 +527,7 @@ class InputPreprocessor:
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> EncoderDecoderInputs:
         """
         For encoder/decoder models only:
@@ -553,7 +564,9 @@ class InputPreprocessor:
         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_inputs = self._prompt_to_llm_inputs(
-                prompt["encoder_prompt"])
+                prompt["encoder_prompt"],
+                tokenization_kwargs=tokenization_kwargs,
+            )

             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
@@ -565,7 +578,10 @@ class InputPreprocessor:
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
         else:
-            inputs = self._prompt_to_llm_inputs(prompt)
+            inputs = self._prompt_to_llm_inputs(
+                prompt,
+                tokenization_kwargs=tokenization_kwargs,
+            )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
                 encoder_inputs, decoder_inputs = (
@@ -581,6 +597,7 @@ class InputPreprocessor:
     async def _process_encoder_decoder_prompt_async(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> EncoderDecoderInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_inputs: SingletonInputs
@@ -588,13 +605,18 @@ class InputPreprocessor:
         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_task = self._prompt_to_llm_inputs_async(
-                prompt["encoder_prompt"])
+                prompt["encoder_prompt"],
+                tokenization_kwargs=tokenization_kwargs,
+            )

             if (decoder_input := prompt["decoder_prompt"]) is None:
                 encoder_inputs = await encoder_task
                 decoder_inputs = None
             else:
-                decoder_task = self._prompt_to_llm_inputs_async(decoder_input)
+                decoder_task = self._prompt_to_llm_inputs_async(
+                    decoder_input,
+                    tokenization_kwargs=tokenization_kwargs,
+                )

                 encoder_inputs, decoder_inputs = await asyncio.gather(
                     encoder_task, decoder_task)
@@ -606,7 +628,10 @@ class InputPreprocessor:
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
         else:
-            inputs = await self._prompt_to_llm_inputs_async(prompt)
+            inputs = await self._prompt_to_llm_inputs_async(
+                prompt,
+                tokenization_kwargs=tokenization_kwargs,
+            )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
                 encoder_inputs, decoder_inputs = (
@@ -638,6 +663,7 @@ class InputPreprocessor:
     def _process_decoder_only_prompt(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -660,6 +686,7 @@ class InputPreprocessor:
         prompt_comps = self._prompt_to_llm_inputs(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )
@@ -672,6 +699,7 @@ class InputPreprocessor:
     async def _process_decoder_only_prompt_async(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -679,6 +707,7 @@ class InputPreprocessor:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
@@ -691,6 +720,7 @@ class InputPreprocessor:
     def preprocess(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -711,6 +741,7 @@ class InputPreprocessor:
         # Decoder-only operation
         return self._process_decoder_only_prompt(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
@@ -719,6 +750,7 @@ class InputPreprocessor:
     async def preprocess_async(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -739,6 +771,7 @@ class InputPreprocessor:
         # Decoder-only operation
         return await self._process_decoder_only_prompt_async(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
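End-to-end, a serving layer could now translate a user-facing truncation knob into the dict that preprocess() accepts. A hypothetical helper, assuming -1 means "truncate to the model's maximum length" (neither the helper nor that convention is defined in this diff):

    from typing import Any, Optional

    def to_tokenization_kwargs(truncate_prompt_tokens: Optional[int],
                               model_max_len: int) -> dict[str, Any]:
        """Hypothetical helper, not part of this diff: map a user-facing
        truncation knob onto the kwargs dict that preprocess() now accepts."""
        if truncate_prompt_tokens is None:
            # No request-level override; keep the tokenizer's default behaviour.
            return {}
        # Assumption: -1 means "truncate to the model's maximum length".
        max_length = (model_max_len
                      if truncate_prompt_tokens == -1 else truncate_prompt_tokens)
        return {"truncation": True, "max_length": max_length}

    print(to_tokenization_kwargs(128, model_max_len=512))   # {'truncation': True, 'max_length': 128}
    print(to_tokenization_kwargs(-1, model_max_len=512))    # {'truncation': True, 'max_length': 512}
    print(to_tokenization_kwargs(None, model_max_len=512))  # {}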