[v0][structured output] Support reasoning output (#12955)

Signed-off-by: Ce Gao <cegao@tensorchord.ai>
This commit is contained in:
Ce Gao
2025-03-03 03:49:42 +08:00
committed by GitHub
parent bc6ccb9878
commit bf33700ecd
16 changed files with 400 additions and 76 deletions

View File

@@ -16,17 +16,33 @@ from vllm.sampling_params import GuidedDecodingParams
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
def test_guided_logits_processors(sample_regex, sample_json_schema):
# Initialize the tokenizer for the model here to avoid repeated loading
@pytest.fixture(scope="module")
def zephyr_7B_tokenzer():
return AutoTokenizer.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
regex_LP = RegexLogitsProcessor(sample_regex,
zephyr_7B_tokenzer,
reasoner=None)
json_LP = JSONLogitsProcessor(sample_json_schema,
tokenizer,
whitespace_pattern=None)
zephyr_7B_tokenzer,
whitespace_pattern=None,
reasoner=None)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
@@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
sample_regex,
sample_json_schema):
sample_json_schema,
zephyr_7B_tokenzer):
config = ModelConfig(
MODEL_NAME,
@@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0,
dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(
regex_request, tokenizer, config) if is_local else \
regex_request, zephyr_7B_tokenzer, config) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, tokenizer, config)
regex_request, zephyr_7B_tokenzer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, tokenizer, config)
json_request, zephyr_7B_tokenzer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
@pytest.mark.parametrize("is_local", [True, False])
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
async def test_guided_logits_processor_with_reasoning(
backend: str, is_local: bool, reasoning_backend: str, sample_regex,
sample_json_schema, deepseek_r1_qwen_tokenizer):
config = ModelConfig(
REASONING_MODEL_NAME,
task="generate",
tokenizer=REASONING_MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="bfloat16",
)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}."
"<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(regex_request,
deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)
# Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process</think> Then")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)