[Model] Whisper model implementation (#11280)
Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
This commit is contained in:
@@ -21,6 +21,25 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
|
||||
MistralTokenizer]
|
||||
|
||||
|
||||
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Encode ``text`` into token ids regardless of the tokenizer backend.

    This mirrors HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.

    Args:
        tokenizer: The tokenizer to encode with.
        text: The string to tokenize.
        add_special_tokens: For a :code:`MistralTokenizer`, this value is
            forwarded unchanged as both the ``bos`` and ``eos`` arguments
            of the wrapped tokenizer's ``encode``. For any other
            tokenizer, it is passed as ``add_special_tokens`` only when it
            is not ``None``; otherwise the argument is omitted so the
            tokenizer falls back to its own default.

    Returns:
        Whatever the selected ``encode`` call returns (annotated as a list
        of token ids).
    """
    if isinstance(tokenizer, MistralTokenizer):
        # The Mistral wrapper exposes the underlying tokenizer, whose
        # ``encode`` takes separate ``bos``/``eos`` switches instead of a
        # single ``add_special_tokens`` flag.
        wrapped = tokenizer.tokenizer
        return wrapped.encode(text,
                              bos=add_special_tokens,
                              eos=add_special_tokens)

    if add_special_tokens is None:
        # Leave the keyword out entirely so the tokenizer's default applies.
        return tokenizer.encode(text)

    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
|
||||
|
||||
|
||||
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
"""Get tokenizer with cached properties.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user