[V1][Core] Support MistralTokenizer for Structured Output (#14625)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
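
The diff below extends the V1 structured-output tests to also run against a MistralTokenizer-backed model ("mistralai/Ministral-8B-Instruct-2410") alongside the existing Qwen model, via a new `model_name` fixture, and adds type annotations to the test signatures. For context, a minimal sketch of the guided-decoding flow these tests exercise, assuming the V1 engine and the xgrammar backend; the schema and prompt here are illustrative, not taken from the test file:

    import json

    import jsonschema

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

    # Illustrative schema; the tests pull theirs from a sample_json_schema fixture.
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }

    # The tests set VLLM_USE_V1=1 before constructing the engine.
    llm = LLM(model="mistralai/Ministral-8B-Instruct-2410", max_model_len=1024)
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"),
    )
    outputs = llm.generate(
        f"Give an example JSON object matching this schema: {json.dumps(schema)}",
        sampling_params)
    generated = outputs[0].outputs[0].text
    # The structured output should parse and satisfy the schema.
    jsonschema.validate(json.loads(generated), schema)
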
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import json
 import re
+from typing import Any
 
 import jsonschema
 import pytest
@@ -10,17 +13,27 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
 
 
+@pytest.fixture
+def model_name():
+    return [
+        "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
+    ]
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_completion(monkeypatch, sample_json_schema,
-                                guided_decoding_backend: str):
+def test_guided_json_completion(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_json_schema: dict[str, Any],
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=1000,
                                      guided_decoding=GuidedDecodingParams(
@@ -50,9 +63,13 @@ def test_guided_json_completion(monkeypatch, sample_json_schema,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
+def test_guided_json_object(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=100,
                                      n=2,
@@ -84,10 +101,14 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
-                                        guided_decoding_backend: str):
+def test_guided_json_unsupported_schema(
+    monkeypatch: pytest.MonkeyPatch,
+    unsupported_json_schema: dict[str, Any],
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=1000,
                                      guided_decoding=GuidedDecodingParams(
@@ -107,10 +128,14 @@ def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
-                             guided_decoding_backend: str):
+def test_guided_grammar_ebnf(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_sql_ebnf: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -145,10 +170,14 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
-                             guided_decoding_backend: str):
+def test_guided_grammar_lark(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_sql_lark: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -188,10 +217,13 @@ def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_ebnf_invalid(monkeypatch,
-                                     guided_decoding_backend: str):
+def test_guided_grammar_ebnf_invalid(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -212,9 +244,14 @@ def test_guided_grammar_ebnf_invalid(monkeypatch,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
+def test_guided_regex(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_regex: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
@@ -243,10 +280,14 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_choice_completion(monkeypatch, sample_guided_choice,
-                                  guided_decoding_backend: str):
+def test_guided_choice_completion(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_guided_choice: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
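
Note on the `model_name` fixture above: as written it returns a list of model ids, while the updated tests annotate the argument as `model_name: str` and pass it straight to `LLM(model=...)`. The usual pytest way to run each test once per model, receiving a single string per run, is a parametrized fixture; a sketch of that pattern (an assumption about the intent, not necessarily what was merged):

    import pytest

    # Hypothetical parametrized fixture: each dependent test runs once per
    # model id and receives a single str as `model_name`.
    @pytest.fixture(params=[
        "Qwen/Qwen2.5-1.5B-Instruct",
        "mistralai/Ministral-8B-Instruct-2410",
    ])
    def model_name(request: pytest.FixtureRequest) -> str:
        return request.param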