feat - add a new endpoint get_tokenizer_info to provide tokenizer/chat-template information (#20575)
Signed-off-by: m-misiura <mmisiura@redhat.com>
This commit is contained in:
@@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
|
||||
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--enable-tokenizer-info-endpoint",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@@ -283,3 +284,106 @@ async def test_detokenize(
|
||||
response.raise_for_status()
|
||||
|
||||
assert response.json() == {"prompt": prompt}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,tokenizer_name",
|
||||
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
|
||||
indirect=["tokenizer_name"],
|
||||
)
|
||||
async def test_tokenizer_info_basic(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
"""Test basic tokenizer info endpoint functionality."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
assert "tokenizer_class" in result
|
||||
assert isinstance(result["tokenizer_class"], str)
|
||||
assert result["tokenizer_class"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_schema(server: RemoteOpenAIServer):
|
||||
"""Test that the response matches expected schema types."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
field_types = {
|
||||
"add_bos_token": bool,
|
||||
"add_prefix_space": bool,
|
||||
"clean_up_tokenization_spaces": bool,
|
||||
"split_special_tokens": bool,
|
||||
"bos_token": str,
|
||||
"eos_token": str,
|
||||
"pad_token": str,
|
||||
"unk_token": str,
|
||||
"chat_template": str,
|
||||
"errors": str,
|
||||
"model_max_length": int,
|
||||
"additional_special_tokens": list,
|
||||
"added_tokens_decoder": dict,
|
||||
}
|
||||
for field, expected_type in field_types.items():
|
||||
if field in result and result[field] is not None:
|
||||
assert isinstance(
|
||||
result[field],
|
||||
expected_type), (f"{field} should be {expected_type.__name__}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_added_tokens_structure(
|
||||
server: RemoteOpenAIServer, ):
|
||||
"""Test added_tokens_decoder structure if present."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
added_tokens = result.get("added_tokens_decoder")
|
||||
if added_tokens:
|
||||
for token_id, token_info in added_tokens.items():
|
||||
assert isinstance(token_id, str), "Token IDs should be strings"
|
||||
assert isinstance(token_info, dict), "Token info should be a dict"
|
||||
assert "content" in token_info, "Token info should have content"
|
||||
assert "special" in token_info, (
|
||||
"Token info should have special flag")
|
||||
assert isinstance(token_info["special"],
|
||||
bool), ("Special flag should be boolean")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_consistency_with_tokenize(
|
||||
server: RemoteOpenAIServer, ):
|
||||
"""Test that tokenizer info is consistent with tokenization endpoint."""
|
||||
info_response = requests.get(server.url_for("tokenizer_info"))
|
||||
info_response.raise_for_status()
|
||||
info = info_response.json()
|
||||
tokenize_response = requests.post(
|
||||
server.url_for("tokenize"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Hello world!"
|
||||
},
|
||||
)
|
||||
tokenize_response.raise_for_status()
|
||||
tokenize_result = tokenize_response.json()
|
||||
info_max_len = info.get("model_max_length")
|
||||
tokenize_max_len = tokenize_result.get("max_model_len")
|
||||
if info_max_len and tokenize_max_len:
|
||||
assert info_max_len >= tokenize_max_len, (
|
||||
"Info max length should be >= tokenize max length")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
|
||||
"""Test chat template is properly included."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
chat_template = result.get("chat_template")
|
||||
if chat_template:
|
||||
assert isinstance(chat_template,
|
||||
str), ("Chat template should be a string")
|
||||
assert chat_template.strip(), "Chat template should not be empty"
|
||||
Reference in New Issue
Block a user