Truncation control for embedding models (#14776)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -55,6 +55,8 @@ def encode_tokens(
|
||||
tokenizer: AnyTokenizer,
|
||||
text: str,
|
||||
*,
|
||||
truncation: Optional[bool] = None,
|
||||
max_length: Optional[int] = None,
|
||||
add_special_tokens: Optional[bool] = None,
|
||||
) -> list[int]:
|
||||
"""
|
||||
@@ -64,10 +66,18 @@ def encode_tokens(
|
||||
:code:`add_special_tokens=None` means to use the backend's default
|
||||
settings.
|
||||
"""
|
||||
if add_special_tokens is not None:
|
||||
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
|
||||
|
||||
return tokenizer.encode(text)
|
||||
kw_args: dict[str, Any] = {}
|
||||
if max_length is not None:
|
||||
kw_args["max_length"] = max_length
|
||||
|
||||
if truncation is not None:
|
||||
kw_args["truncation"] = truncation
|
||||
|
||||
if add_special_tokens is not None:
|
||||
kw_args["add_special_tokens"] = add_special_tokens
|
||||
|
||||
return tokenizer.encode(text, **kw_args)
|
||||
|
||||
|
||||
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
|
||||
Reference in New Issue
Block a user