[v0][structured output] Support reasoning output (#12955)
Signed-off-by: Ce Gao <cegao@tensorchord.ai>
This commit is contained in:
@@ -16,17 +16,33 @@ from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
|
||||
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
|
||||
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
|
||||
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
|
||||
|
||||
def test_guided_logits_processors(sample_regex, sample_json_schema):
|
||||
# Initialize the tokenizer for the model here to avoid repeated loading
|
||||
@pytest.fixture(scope="module")
|
||||
def zephyr_7B_tokenzer():
|
||||
return AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def deepseek_r1_qwen_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||
|
||||
|
||||
def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
|
||||
sample_json_schema):
|
||||
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
||||
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
||||
regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
|
||||
regex_LP = RegexLogitsProcessor(sample_regex,
|
||||
zephyr_7B_tokenzer,
|
||||
reasoner=None)
|
||||
json_LP = JSONLogitsProcessor(sample_json_schema,
|
||||
tokenizer,
|
||||
whitespace_pattern=None)
|
||||
zephyr_7B_tokenzer,
|
||||
whitespace_pattern=None,
|
||||
reasoner=None)
|
||||
|
||||
token_ids = tokenizer.encode(
|
||||
token_ids = zephyr_7B_tokenzer.encode(
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
@@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = tokenizer.encode(
|
||||
token_ids = zephyr_7B_tokenzer.encode(
|
||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
||||
)
|
||||
tensor = torch.rand(32000)
|
||||
@@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
|
||||
@pytest.mark.parametrize("is_local", [True, False])
|
||||
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||
sample_regex,
|
||||
sample_json_schema):
|
||||
sample_json_schema,
|
||||
zephyr_7B_tokenzer):
|
||||
|
||||
config = ModelConfig(
|
||||
MODEL_NAME,
|
||||
@@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||
seed=0,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
token_ids = tokenizer.encode(
|
||||
token_ids = zephyr_7B_tokenzer.encode(
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||
|
||||
regex_lp = get_local_guided_decoding_logits_processor(
|
||||
regex_request, tokenizer, config) if is_local else \
|
||||
regex_request, zephyr_7B_tokenzer, config) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
regex_request, tokenizer, config)
|
||||
regex_request, zephyr_7B_tokenzer, config)
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
@@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = tokenizer.encode(
|
||||
token_ids = zephyr_7B_tokenzer.encode(
|
||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
||||
)
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = await get_guided_decoding_logits_processor(
|
||||
json_request, tokenizer, config)
|
||||
json_request, zephyr_7B_tokenzer, config)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("backend",
|
||||
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
|
||||
@pytest.mark.parametrize("is_local", [True, False])
|
||||
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
|
||||
async def test_guided_logits_processor_with_reasoning(
|
||||
backend: str, is_local: bool, reasoning_backend: str, sample_regex,
|
||||
sample_json_schema, deepseek_r1_qwen_tokenizer):
|
||||
|
||||
config = ModelConfig(
|
||||
REASONING_MODEL_NAME,
|
||||
task="generate",
|
||||
tokenizer=REASONING_MODEL_NAME,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}."
|
||||
"<think>here is the thinking process")
|
||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||
|
||||
regex_lp = get_local_guided_decoding_logits_processor(regex_request,
|
||||
deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
regex_request, deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend)
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = regex_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
f"Give an employee profile that fits this schema: {sample_json_schema}."
|
||||
"<think>here is the thinking process")
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = get_local_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert torch.allclose(tensor, original_tensor)
|
||||
|
||||
# Thinking is over, so the tensor should change.
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
f"Give an employee profile that fits this schema: {sample_json_schema}."
|
||||
"<think>here is the thinking process</think> Then")
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = get_local_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
|
||||
Reference in New Issue
Block a user