Truncation control for embedding models (#14776)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Gabriel Marinho
2025-04-29 22:24:57 -03:00
committed by GitHub
parent 4055130a85
commit 1c2bc7ead0
21 changed files with 333 additions and 71 deletions

View File

@@ -55,6 +55,8 @@ def encode_tokens(
tokenizer: AnyTokenizer,
text: str,
*,
truncation: Optional[bool] = None,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None,
) -> list[int]:
"""
@@ -64,10 +66,18 @@ def encode_tokens(
:code:`add_special_tokens=None` means to use the backend's default
settings.
"""
if add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text)
kw_args: dict[str, Any] = {}
if max_length is not None:
kw_args["max_length"] = max_length
if truncation is not None:
kw_args["truncation"] = truncation
if add_special_tokens is not None:
kw_args["add_special_tokens"] = add_special_tokens
return tokenizer.encode(text, **kw_args)
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: