Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
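The change is mechanical: yapf's bracket-aligned continuation style and isort's parenthesized import wrapping are replaced by ruff's black-compatible formatter, with no behavioral changes. A minimal sketch of the recurring pattern, reusing names from the diff below (the commit's actual ruff configuration is not shown on this page):

```python
# Before (yapf + isort): continuation lines aligned under the opening bracket.
reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
                                                    trust_remote_code=True)

# After (ruff, black-compatible): break after the opening bracket, indent by
# four spaces, and use trailing commas to keep exploded calls one-per-line.
reference_tokenizer = AutoTokenizer.from_pretrained(
    model_id, trust_remote_code=True
)
```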
@@ -6,17 +6,16 @@ from copy import deepcopy
 import pytest
 from transformers import AutoTokenizer
 
-from vllm.transformers_utils.tokenizer import (AnyTokenizer,
-                                               get_cached_tokenizer)
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer
 
 
 @pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
 def test_cached_tokenizer(model_id: str):
-    reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
-                                                        trust_remote_code=True)
+    reference_tokenizer = AutoTokenizer.from_pretrained(
+        model_id, trust_remote_code=True
+    )
     reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
-    reference_tokenizer.add_special_tokens(
-        {"additional_special_tokens": ["<SEP>"]})
+    reference_tokenizer.add_special_tokens({"additional_special_tokens": ["<SEP>"]})
 
     cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
     _check_consistency(cached_tokenizer, reference_tokenizer)
@@ -32,13 +31,13 @@ def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
     # Cached attributes
     assert target.all_special_ids == expected.all_special_ids
     assert target.all_special_tokens == expected.all_special_tokens
-    assert (target.all_special_tokens_extended ==
-            expected.all_special_tokens_extended)
+    assert target.all_special_tokens_extended == expected.all_special_tokens_extended
     assert target.get_vocab() == expected.get_vocab()
     assert len(target) == len(expected)
 
     # Other attributes
-    assert getattr(target, "padding_side",
-                   None) == getattr(expected, "padding_side", None)
+    assert getattr(target, "padding_side", None) == getattr(
+        expected, "padding_side", None
+    )
 
     assert target.encode("prompt") == expected.encode("prompt")
@@ -5,15 +5,16 @@ from collections.abc import Generator
 from typing import Any, Optional
 
 import pytest
-from transformers import (AutoTokenizer, PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
-                                        IncrementalDetokenizer,
-                                        SlowIncrementalDetokenizer)
+from vllm.v1.engine.detokenizer import (
+    FastIncrementalDetokenizer,
+    IncrementalDetokenizer,
+    SlowIncrementalDetokenizer,
+)
 
 SPECIAL_TOKS_TRUTH = [
     "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
@@ -45,33 +46,35 @@ TOKENIZERS = [
 ]
 
 
-def _run_incremental_decode(tokenizer,
-                            all_input_ids,
-                            skip_special_tokens: bool,
-                            starting_index: int,
-                            spaces_between_special_tokens: bool = True,
-                            fast: Optional[bool] = None):
-
+def _run_incremental_decode(
+    tokenizer,
+    all_input_ids,
+    skip_special_tokens: bool,
+    starting_index: int,
+    spaces_between_special_tokens: bool = True,
+    fast: Optional[bool] = None,
+):
     prompt_token_ids = all_input_ids[:starting_index]
 
     params = SamplingParams(
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest(request_id="",
-                                prompt_token_ids=prompt_token_ids,
-                                mm_features=None,
-                                sampling_params=params,
-                                pooling_params=None,
-                                eos_token_id=None,
-                                arrival_time=0.0,
-                                lora_request=None,
-                                cache_salt=None,
-                                data_parallel_rank=None)
+    request = EngineCoreRequest(
+        request_id="",
+        prompt_token_ids=prompt_token_ids,
+        mm_features=None,
+        sampling_params=params,
+        pooling_params=None,
+        eos_token_id=None,
+        arrival_time=0.0,
+        lora_request=None,
+        cache_salt=None,
+        data_parallel_rank=None,
+    )
 
     if fast is None:
-        detokenizer = IncrementalDetokenizer.from_new_request(
-            tokenizer, request)
+        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
     elif fast:
         detokenizer = FastIncrementalDetokenizer(tokenizer, request)
     else:
@@ -88,9 +91,11 @@ def _run_incremental_decode(tokenizer,
 
 @pytest.fixture
 def tokenizer(tokenizer_name):
-    return (MistralTokenizer.from_pretrained(tokenizer_name)
-            if "mistral" in tokenizer_name else
-            AutoTokenizer.from_pretrained(tokenizer_name))
+    return (
+        MistralTokenizer.from_pretrained(tokenizer_name)
+        if "mistral" in tokenizer_name
+        else AutoTokenizer.from_pretrained(tokenizer_name)
+    )
 
 
 @pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@@ -102,7 +107,8 @@ def tokenizer(tokenizer_name):
         "ပုံပြင်လေးပြောပြပါ",
         # Using "URGENCY" since "CY" has token id 130282
         "URGENCY🌶️",
-    ])
+    ],
+)
 def test_mistral_edge_case(tokenizer, truth):
     """Test for a specific edge cases with V3-Tekken MistralTokenizer.
 
@@ -115,7 +121,8 @@ def test_mistral_edge_case(tokenizer, truth):
         tokenizer,
         all_input_ids,
         skip_special_tokens=True,
-        starting_index=starting_index)
+        starting_index=starting_index,
+    )
     assert decoded_text == truth
     assert out_ids == all_input_ids[starting_index:]
 
@@ -124,8 +131,10 @@ def test_mistral_edge_case(tokenizer, truth):
 def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
     if "mistral" in tokenizer_name:
         yield (
-            True if request.param else
-            pytest.skip("mistral doesn't support skip_special_tokens=False"))
+            True
+            if request.param
+            else pytest.skip("mistral doesn't support skip_special_tokens=False")
+        )
     else:
         yield bool(request.param)
 
@@ -136,8 +145,14 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
 @pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
 @pytest.mark.parametrize("fast", (True, False))
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
-                          spaces_between_special_tokens, fast):
+def test_decode_streaming(
+    tokenizer,
+    truth,
+    with_prompt,
+    skip_special_tokens,
+    spaces_between_special_tokens,
+    fast,
+):
     if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
         pytest.skip()
 
@@ -146,30 +161,35 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
 
     if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
         # Fix up inconsistency in fast/slow tokenizer behaviour.
-        tokenizer.add_special_tokens({
-            "additional_special_tokens": [
-                at for at in
-                tokenizer._tokenizer.get_added_tokens_decoder().values()
-                if at.special
-            ]
-        })
+        tokenizer.add_special_tokens(
+            {
+                "additional_special_tokens": [
+                    at
+                    for at in tokenizer._tokenizer.get_added_tokens_decoder().values()
+                    if at.special
+                ]
+            }
+        )
 
-    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+    extra_decode_args = (
+        {}
+        if not isinstance(tokenizer, PreTrainedTokenizer)
         else {"spaces_between_special_tokens": spaces_between_special_tokens}
+    )
 
     truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
     if tokenizer.bos_token_id is not None:
         truth_tokens.insert(0, tokenizer.bos_token_id)
     truth_tokens.append(tokenizer.eos_token_id)
 
-    new_truth = tokenizer.decode(truth_tokens,
-                                 skip_special_tokens=skip_special_tokens,
-                                 **extra_decode_args)
+    new_truth = tokenizer.decode(
+        truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args
+    )
 
     if with_prompt:
         num_prompt_tokens = len(
-            tokenizer(truth[:len(truth) // 2],
-                      add_special_tokens=False).input_ids)
+            tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids
+        )
         if tokenizer.bos_token_id is not None:
             num_prompt_tokens += 1
@@ -177,11 +197,13 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
         generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
-        prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens,
-                                  **extra_decode_args)
+        prompt = tokenizer.decode(
+            prompt_input_ids,
+            skip_special_tokens=skip_special_tokens,
+            **extra_decode_args,
+        )
 
-        generated = new_truth[len(prompt):]
+        generated = new_truth[len(prompt) :]
     else:
         generated = new_truth
         starting_index = 0
@@ -193,7 +215,8 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
         skip_special_tokens=skip_special_tokens,
         starting_index=starting_index,
         spaces_between_special_tokens=spaces_between_special_tokens,
-        fast=fast)
+        fast=fast,
+    )
 
     assert decoded_text == generated
     assert out_ids == all_input_ids[starting_index:]
@@ -206,11 +229,13 @@ def test_oov_decode(tokenizer, fast):
         pytest.skip()
 
     decoded_text, out_ids = _run_incremental_decode(
-        tokenizer, [len(tokenizer)],
+        tokenizer,
+        [len(tokenizer)],
        skip_special_tokens=True,
         starting_index=0,
         spaces_between_special_tokens=True,
-        fast=fast)
+        fast=fast,
+    )
 
-    assert decoded_text == ''
+    assert decoded_text == ""
     assert out_ids == [len(tokenizer)]
@@ -13,6 +13,6 @@ TOKENIZER_NAMES = ["BAAI/bge-base-en"]
 def test_special_tokens(tokenizer_name: str, n_tokens: int):
     tokenizer = get_tokenizer(tokenizer_name, revision="main")
 
-    prompts = '[UNK]' * n_tokens
+    prompts = "[UNK]" * n_tokens
     prompt_token_ids = tokenizer.encode(prompts)
     assert len(prompt_token_ids) == n_tokens + 2
@@ -5,6 +5,7 @@ This test file includes some cases where it is inappropriate to
 only get the `eos_token_id` from the tokenizer as defined by
 {meth}`vllm.LLMEngine._get_eos_token_id`.
 """
+
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -15,8 +16,7 @@ def test_get_llama3_eos_token():
     tokenizer = get_tokenizer(model_name)
     assert tokenizer.eos_token_id == 128009
 
-    generation_config = try_get_generation_config(model_name,
-                                                  trust_remote_code=False)
+    generation_config = try_get_generation_config(model_name, trust_remote_code=False)
     assert generation_config is not None
     assert generation_config.eos_token_id == [128001, 128008, 128009]
 
@@ -27,7 +27,6 @@ def test_get_blip2_eos_token():
     tokenizer = get_tokenizer(model_name)
     assert tokenizer.eos_token_id == 2
 
-    generation_config = try_get_generation_config(model_name,
-                                                  trust_remote_code=False)
+    generation_config = try_get_generation_config(model_name, trust_remote_code=False)
     assert generation_config is not None
     assert generation_config.eos_token_id == 50118
@@ -2,187 +2,206 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import pytest
-from mistral_common.protocol.instruct.messages import (AssistantMessage,
-                                                       ToolMessage,
-                                                       UserMessage)
+from mistral_common.protocol.instruct.messages import (
+    AssistantMessage,
+    ToolMessage,
+    UserMessage,
+)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.protocol.instruct.tool_calls import (Function,
-                                                         FunctionCall, Tool,
-                                                         ToolCall)
+from mistral_common.protocol.instruct.tool_calls import (
+    Function,
+    FunctionCall,
+    Tool,
+    ToolCall,
+)
 
 from vllm.transformers_utils.tokenizers.mistral import (
-    make_mistral_chat_completion_request)
+    make_mistral_chat_completion_request,
+)
 
 
 @pytest.mark.parametrize(
     "openai_request,expected_mistral_request",
-    [(
-        {
-            "messages": [{
-                "role": "user",
-                "content": "What is the current local date and time?",
-            }],
-            "tools": [{
-                "type": "function",
-                "function": {
-                    "description": "Fetch the current local date and time.",
-                    "name": "get_current_time",
-                },
-            }],
-        },
-        ChatCompletionRequest(
-            messages=[
-                UserMessage(content="What is the current local date and time?")
-            ],
-            tools=[
-                Tool(
-                    type="function",
-                    function=Function(
-                        name="get_current_time",
-                        description="Fetch the current local date and time.",
-                        parameters={},
-                    ),
-                )
-            ],
-        ),
-    ),
-     (
-         {
-             "messages":
-             [{
-                 "role": "user",
-                 "content": "What is the current local date and time?",
-             }],
-             "tools": [{
-                 "type": "function",
-                 "function": {
-                     "description": "Fetch the current local date and time.",
-                     "name": "get_current_time",
-                     "parameters": None,
-                 },
-             }],
-         },
-         ChatCompletionRequest(
-             messages=[
-                 UserMessage(
-                     content="What is the current local date and time?")
-             ],
-             tools=[
-                 Tool(
-                     type="function",
-                     function=Function(
-                         name="get_current_time",
-                         description="Fetch the current local date and time.",
-                         parameters={},
-                     ),
-                 )
-             ],
-         ),
-     )],
+    [
+        (
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "What is the current local date and time?",
+                    }
+                ],
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "description": "Fetch the current local date and time.",
+                            "name": "get_current_time",
+                        },
+                    }
+                ],
+            },
+            ChatCompletionRequest(
+                messages=[
+                    UserMessage(content="What is the current local date and time?")
+                ],
+                tools=[
+                    Tool(
+                        type="function",
+                        function=Function(
+                            name="get_current_time",
+                            description="Fetch the current local date and time.",
+                            parameters={},
+                        ),
+                    )
+                ],
+            ),
+        ),
+        (
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "What is the current local date and time?",
+                    }
+                ],
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "description": "Fetch the current local date and time.",
+                            "name": "get_current_time",
+                            "parameters": None,
+                        },
+                    }
+                ],
+            },
+            ChatCompletionRequest(
+                messages=[
+                    UserMessage(content="What is the current local date and time?")
+                ],
+                tools=[
+                    Tool(
+                        type="function",
+                        function=Function(
+                            name="get_current_time",
+                            description="Fetch the current local date and time.",
+                            parameters={},
+                        ),
+                    )
+                ],
+            ),
+        ),
+    ],
 )
-def test_make_mistral_chat_completion_request(openai_request,
-                                              expected_mistral_request):
+def test_make_mistral_chat_completion_request(openai_request, expected_mistral_request):
     actual_request = make_mistral_chat_completion_request(
-        openai_request["messages"], openai_request["tools"])
+        openai_request["messages"], openai_request["tools"]
+    )
     assert actual_request == expected_mistral_request
 
 
 # Tool use with list content and reasoning_content
-@pytest.mark.parametrize("openai_request,expected_mistral_request", [(
-    {
-        "messages": [
-            {
-                "role": "user",
-                "content": "What's the weather in Paris?",
-            },
-            {
-                "role":
-                "assistant",
-                "reasoning_content":
-                None,
-                "content":
-                None,
-                "tool_calls": [{
-                    "id": "call123",
-                    "type": "function",
-                    "function": {
-                        "name": "get_weather",
-                        "arguments": '{"city": "Paris"}',
-                    },
-                }],
-            },
-            {
-                "role": "tool",
-                "content": [{
-                    "type": "text",
-                    "text": "Rainy"
-                }],
-                "name": "get_weather",
-                "tool_call_id": "call123",
-            },
-        ],
-        "tools": [{
-            "type": "function",
-            "function": {
-                "name": "get_weather",
-                "description": "Gets the current weather in a city.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "city": {
-                            "type": "string",
-                            "description": "The city name"
-                        }
-                    },
-                    "required": ["city"],
-                },
-            },
-        }],
-    },
-    ChatCompletionRequest(
-        messages=[
-            UserMessage(content="What's the weather in Paris?"),
-            AssistantMessage(
-                content=None,
-                tool_calls=[
-                    ToolCall(
-                        id="call123",
-                        function=FunctionCall(
-                            name="get_weather",
-                            arguments='{"city": "Paris"}',
-                        ),
-                    )
-                ],
-            ),
-            ToolMessage(
-                content="Rainy",
-                tool_call_id="call123",
-                name="get_weather",
-            ),
-        ],
-        tools=[
-            Tool(
-                type="function",
-                function=Function(
-                    name="get_weather",
-                    description="Gets the current weather in a city.",
-                    parameters={
-                        "type": "object",
-                        "properties": {
-                            "city": {
-                                "type": "string",
-                                "description": "The city name"
-                            }
-                        },
-                        "required": ["city"],
-                    },
-                ),
-            )
-        ],
-    ),
-)])
-def test_make_mistral_chat_completion_request_list_content(
-        openai_request, expected_mistral_request):
+@pytest.mark.parametrize(
+    "openai_request,expected_mistral_request",
+    [
+        (
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "What's the weather in Paris?",
+                    },
+                    {
+                        "role": "assistant",
+                        "reasoning_content": None,
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "id": "call123",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_weather",
+                                    "arguments": '{"city": "Paris"}',
+                                },
+                            }
+                        ],
+                    },
+                    {
+                        "role": "tool",
+                        "content": [{"type": "text", "text": "Rainy"}],
+                        "name": "get_weather",
+                        "tool_call_id": "call123",
+                    },
+                ],
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "get_weather",
+                            "description": "Gets the current weather in a city.",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "city": {
+                                        "type": "string",
+                                        "description": "The city name",
+                                    }
+                                },
+                                "required": ["city"],
+                            },
+                        },
+                    }
+                ],
+            },
+            ChatCompletionRequest(
+                messages=[
+                    UserMessage(content="What's the weather in Paris?"),
+                    AssistantMessage(
+                        content=None,
+                        tool_calls=[
+                            ToolCall(
+                                id="call123",
+                                function=FunctionCall(
+                                    name="get_weather",
+                                    arguments='{"city": "Paris"}',
+                                ),
+                            )
+                        ],
+                    ),
+                    ToolMessage(
+                        content="Rainy",
+                        tool_call_id="call123",
+                        name="get_weather",
+                    ),
+                ],
+                tools=[
+                    Tool(
+                        type="function",
+                        function=Function(
+                            name="get_weather",
+                            description="Gets the current weather in a city.",
+                            parameters={
+                                "type": "object",
+                                "properties": {
+                                    "city": {
+                                        "type": "string",
+                                        "description": "The city name",
+                                    }
+                                },
+                                "required": ["city"],
+                            },
+                        ),
+                    )
+                ],
+            ),
+        )
+    ],
+)
+def test_make_mistral_chat_completion_request_list_content(
+    openai_request, expected_mistral_request
+):
     actual_request = make_mistral_chat_completion_request(
-        openai_request["messages"], openai_request["tools"])
+        openai_request["messages"], openai_request["tools"]
+    )
     assert actual_request == expected_mistral_request
@@ -19,5 +19,5 @@ def test_tokenizer_revision(tokenizer_name: str):
     assert isinstance(tokenizer, PreTrainedTokenizerBase)
 
     # Assume that "never" branch always does not exist
-    with pytest.raises(OSError, match='not a valid git identifier'):
+    with pytest.raises(OSError, match="not a valid git identifier"):
         get_tokenizer(tokenizer_name, revision="never")
@@ -4,15 +4,13 @@
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from vllm.transformers_utils.tokenizer import get_tokenizer
-from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
-                                                    TokenizerRegistry)
+from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry
 
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 
 
 class TestTokenizer(TokenizerBase):
-
     @classmethod
     def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
         return TestTokenizer()
@@ -85,23 +83,23 @@ class TestTokenizer(TokenizerBase):
     ) -> list[int]:
         raise NotImplementedError()
 
-    def encode(self,
-               text: str,
-               add_special_tokens: Optional[bool] = None) -> list[int]:
+    def encode(self, text: str, add_special_tokens: Optional[bool] = None) -> list[int]:
         raise NotImplementedError()
 
-    def apply_chat_template(self,
-                            messages: list["ChatCompletionMessageParam"],
-                            tools: Optional[list[dict[str, Any]]] = None,
-                            **kwargs) -> list[int]:
+    def apply_chat_template(
+        self,
+        messages: list["ChatCompletionMessageParam"],
+        tools: Optional[list[dict[str, Any]]] = None,
+        **kwargs,
+    ) -> list[int]:
         raise NotImplementedError()
 
     def convert_tokens_to_string(self, tokens: list[str]) -> str:
         raise NotImplementedError()
 
-    def decode(self,
-               ids: Union[list[int], int],
-               skip_special_tokens: bool = True) -> str:
+    def decode(
+        self, ids: Union[list[int], int], skip_special_tokens: bool = True
+    ) -> str:
         raise NotImplementedError()
 
     def convert_ids_to_tokens(
@@ -113,9 +111,9 @@ class TestTokenizer(TokenizerBase):
 
 
 def test_customized_tokenizer():
-    TokenizerRegistry.register("test_tokenizer",
-                               "tests.tokenization.test_tokenizer_registry",
-                               "TestTokenizer")
+    TokenizerRegistry.register(
+        "test_tokenizer", "tests.tokenization.test_tokenizer_registry", "TestTokenizer"
+    )
 
     tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
     assert isinstance(tokenizer, TestTokenizer)