Fix wrong truncate_prompt_tokens type hint (#22761)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com> Signed-off-by: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -346,6 +346,22 @@ class InputPreprocessor:
|
||||
) -> EmbedsInputs:
|
||||
return self._process_embeds(parsed_content)
|
||||
|
||||
def _truncate_inputs(
|
||||
self,
|
||||
inputs: list[int],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
|
||||
|
||||
if not tokenization_kwargs or "truncation" not in \
|
||||
tokenization_kwargs or self.tokenizer is None:
|
||||
return inputs
|
||||
|
||||
max_length = tokenization_kwargs["max_length"]
|
||||
|
||||
if self.tokenizer.truncation_side == "left":
|
||||
return inputs[-max_length:]
|
||||
else:
|
||||
return inputs[:max_length]
|
||||
|
||||
def _process_tokens(
|
||||
self,
|
||||
parsed_content: TokensPrompt,
|
||||
@@ -354,7 +370,8 @@ class InputPreprocessor:
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> Union[TokenInputs, MultiModalInputs]:
|
||||
prompt_token_ids = parsed_content["prompt_token_ids"]
|
||||
prompt_token_ids = self._truncate_inputs(
|
||||
parsed_content["prompt_token_ids"], tokenization_kwargs)
|
||||
|
||||
inputs: Union[TokenInputs, MultiModalInputs]
|
||||
if multi_modal_data := parsed_content.get("multi_modal_data"):
|
||||
@@ -382,7 +399,8 @@ class InputPreprocessor:
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> Union[TokenInputs, MultiModalInputs]:
|
||||
prompt_token_ids = parsed_content["prompt_token_ids"]
|
||||
prompt_token_ids = self._truncate_inputs(
|
||||
parsed_content["prompt_token_ids"], tokenization_kwargs)
|
||||
|
||||
inputs: Union[TokenInputs, MultiModalInputs]
|
||||
if multi_modal_data := parsed_content.get("multi_modal_data"):
|
||||
|
||||
Reference in New Issue
Block a user