[V1][Feature] Enable Speculative Decoding with Structured Outputs (#14702)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
committed by
GitHub
parent
7489ec0bab
commit
34120f5acd
@@ -16,13 +16,31 @@ from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
|
||||
NGRAM_SPEC_CONFIG = {
|
||||
"model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 1,
|
||||
}
|
||||
|
||||
EAGLE_SPEC_CONFIG = {
|
||||
"method": "eagle",
|
||||
"model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
|
||||
"num_speculative_tokens": 5,
|
||||
}
|
||||
|
||||
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
|
||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
|
||||
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
||||
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
||||
#FIXME: This test is flaky on CI thus disabled
|
||||
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
|
||||
NGRAM_SPEC_CONFIG),
|
||||
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
|
||||
("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto",
|
||||
EAGLE_SPEC_CONFIG)
|
||||
]
|
||||
|
||||
PARAMS_MODELS_TOKENIZER_MODE = [
|
||||
@@ -45,8 +63,9 @@ class CarDescription(BaseModel):
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode",
|
||||
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
|
||||
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
|
||||
def test_structured_output(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sample_json_schema: dict[str, Any],
|
||||
@@ -58,6 +77,7 @@ def test_structured_output(
|
||||
guided_decoding_backend: str,
|
||||
tokenizer_mode: str,
|
||||
model_name: str,
|
||||
speculative_config: dict[str, Any],
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -71,7 +91,8 @@ def test_structured_output(
|
||||
max_model_len=1024,
|
||||
guided_decoding_backend=guided_decoding_backend,
|
||||
guided_decoding_disable_any_whitespace=True,
|
||||
tokenizer_mode=tokenizer_mode)
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
speculative_config=speculative_config)
|
||||
|
||||
#
|
||||
# Test 1: Generate JSON output based on a provided schema
|
||||
|
||||
Reference in New Issue
Block a user