Truncation control for embedding models (#14776)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
@@ -2,7 +2,7 @@
 
 import asyncio
 from collections.abc import Mapping
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast
 
 from typing_extensions import assert_never
 
@@ -183,18 +183,21 @@ class InputPreprocessor:
         self,
         prompt: str,
         lora_request: Optional[LoRARequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[int]:
         """
         Apply the model's tokenizer to a text prompt, returning the
         corresponding token IDs.
         """
         tokenizer = self.get_tokenizer_group()
-        add_special_tokens = None
+        if tokenization_kwargs is None:
+            tokenization_kwargs = {}
+
         if self.model_config.hf_config.model_type == "whisper":
             # For Whisper, special tokens should be provided by the user based
             # on the task and language of their request. Also needed to avoid
             # appending an EOS token to the prompt which disrupts generation.
-            add_special_tokens = False
+            tokenization_kwargs["add_special_tokens"] = False
 
         if (self.model_config.encoder_config is not None
                 and self.model_config.encoder_config.get(
@@ -203,25 +206,27 @@ class InputPreprocessor:
 
         return tokenizer.encode(prompt=prompt,
                                 lora_request=lora_request,
-                                add_special_tokens=add_special_tokens)
+                                **tokenization_kwargs)
 
     async def _tokenize_prompt_async(
         self,
         prompt: str,
         lora_request: Optional[LoRARequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[int]:
         """Async version of :meth:`_tokenize_prompt`."""
         tokenizer = self.get_tokenizer_group()
-        add_special_tokens = None
+        if tokenization_kwargs is None:
+            tokenization_kwargs = {}
+
         if self.model_config.hf_config.model_type == "whisper":
             # For Whisper, special tokens should be provided by the user based
             # on the task and language of their request. Also needed to avoid
             # appending an EOS token to the prompt which disrupts generation.
-            add_special_tokens = False
-        return await tokenizer.encode_async(
-            prompt=prompt,
-            lora_request=lora_request,
-            add_special_tokens=add_special_tokens)
+            tokenization_kwargs["add_special_tokens"] = False
+        return await tokenizer.encode_async(prompt=prompt,
+                                            lora_request=lora_request,
+                                            **tokenization_kwargs)
 
     def _process_multimodal(
         self,
@@ -281,6 +286,7 @@ class InputPreprocessor:
     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -304,6 +310,7 @@ class InputPreprocessor:
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 lora_request=lora_request,
+                tokenization_kwargs=tokenization_kwargs,
             )
 
             return token_inputs(
@@ -352,6 +359,7 @@ class InputPreprocessor:
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 lora_request=lora_request,
+                tokenization_kwargs=tokenization_kwargs,
             )
 
             return token_inputs(
@@ -364,6 +372,7 @@ class InputPreprocessor:
     async def _prompt_to_llm_inputs_async(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -375,6 +384,7 @@ class InputPreprocessor:
             prompt_token_ids = await self._tokenize_prompt_async(
                 prompt_text,
                 lora_request=lora_request,
+                tokenization_kwargs=tokenization_kwargs,
             )
 
             return token_inputs(
@@ -517,6 +527,7 @@ class InputPreprocessor:
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> EncoderDecoderInputs:
         """
         For encoder/decoder models only:
@@ -553,7 +564,9 @@ class InputPreprocessor:
 
         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_inputs = self._prompt_to_llm_inputs(
-                prompt["encoder_prompt"])
+                prompt["encoder_prompt"],
+                tokenization_kwargs=tokenization_kwargs,
+            )
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
@@ -565,7 +578,10 @@ class InputPreprocessor:
                 self._separate_enc_dec_inputs_from_mm_processor_outputs(
                     encoder_inputs, decoder_inputs))
         else:
-            inputs = self._prompt_to_llm_inputs(prompt)
+            inputs = self._prompt_to_llm_inputs(
+                prompt,
+                tokenization_kwargs=tokenization_kwargs,
+            )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
                 encoder_inputs, decoder_inputs = (
@@ -581,6 +597,7 @@ class InputPreprocessor:
     async def _process_encoder_decoder_prompt_async(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> EncoderDecoderInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_inputs: SingletonInputs
@@ -588,13 +605,18 @@ class InputPreprocessor:
 
         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_task = self._prompt_to_llm_inputs_async(
-                prompt["encoder_prompt"])
+                prompt["encoder_prompt"],
+                tokenization_kwargs=tokenization_kwargs,
+            )
 
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 encoder_inputs = await encoder_task
                 decoder_inputs = None
             else:
-                decoder_task = self._prompt_to_llm_inputs_async(decoder_input)
+                decoder_task = self._prompt_to_llm_inputs_async(
+                    decoder_input,
+                    tokenization_kwargs=tokenization_kwargs,
+                )
 
                 encoder_inputs, decoder_inputs = await asyncio.gather(
                     encoder_task, decoder_task)
@@ -606,7 +628,10 @@ class InputPreprocessor:
                 self._separate_enc_dec_inputs_from_mm_processor_outputs(
                     encoder_inputs, decoder_inputs))
         else:
-            inputs = await self._prompt_to_llm_inputs_async(prompt)
+            inputs = await self._prompt_to_llm_inputs_async(
+                prompt,
+                tokenization_kwargs=tokenization_kwargs,
+            )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
                 encoder_inputs, decoder_inputs = (
@@ -638,6 +663,7 @@ class InputPreprocessor:
     def _process_decoder_only_prompt(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -660,6 +686,7 @@ class InputPreprocessor:
 
         prompt_comps = self._prompt_to_llm_inputs(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )
@@ -672,6 +699,7 @@ class InputPreprocessor:
     async def _process_decoder_only_prompt_async(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -679,6 +707,7 @@ class InputPreprocessor:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._prompt_to_llm_inputs_async(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )
@@ -691,6 +720,7 @@ class InputPreprocessor:
     def preprocess(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -711,6 +741,7 @@ class InputPreprocessor:
         # Decoder-only operation
         return self._process_decoder_only_prompt(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
@@ -719,6 +750,7 @@ class InputPreprocessor:
     async def preprocess_async(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -739,6 +771,7 @@ class InputPreprocessor:
         # Decoder-only operation
         return await self._process_decoder_only_prompt_async(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
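Note: the diff above only threads a tokenization_kwargs dict from preprocess() / preprocess_async() down to tokenizer.encode() / encode_async(); the truncation behaviour itself comes from whatever the underlying tokenizer does with those kwargs. A minimal sketch of the intended effect, assuming the dict is eventually forwarded to a Hugging Face tokenizer (the model name and the 128-token limit below are illustrative assumptions, not values taken from this commit):

    # Illustrative only: how kwargs forwarded via **tokenization_kwargs
    # typically control truncation once they reach a Hugging Face tokenizer.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    long_prompt = "lorem ipsum " * 1000

    # Without kwargs, the full prompt is tokenized, however long it is.
    full_ids = tokenizer.encode(long_prompt)

    # With truncation kwargs, the token IDs are capped at max_length,
    # which is what an embedding request with a prompt-length limit needs.
    truncated_ids = tokenizer.encode(long_prompt, truncation=True, max_length=128)

    print(len(full_ids), len(truncated_ids))

Under that assumption, a caller that wants embedding inputs capped at N tokens would pass something like tokenization_kwargs={"truncation": True, "max_length": N} into preprocess() or preprocess_async(), and the dict flows unchanged through _prompt_to_llm_inputs and _tokenize_prompt to the tokenizer.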