[V1] Add structural_tag support using xgrammar (#17085)

This commit is contained in:
Russell Bryant
2025-04-26 10:06:37 -04:00
committed by GitHub
parent c48334d405
commit f8acd01ff7
10 changed files with 270 additions and 15 deletions

View File

@@ -2,6 +2,7 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import re
import time
from argparse import Namespace
@@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
strict: Optional[bool] = None
class StructuralTag(OpenAIBaseModel):
begin: str
# schema is the field, but that causes conflicts with pydantic so
# instead use structural_tag_schema with an alias
structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
alias="schema")
end: str
class StructuralTagResponseFormat(OpenAIBaseModel):
type: Literal["structural_tag"]
structures: list[StructuralTag]
triggers: list[str]
class ResponseFormat(OpenAIBaseModel):
# type must be "json_schema", "json_object" or "text"
# type must be "json_schema", "json_object", or "text"
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]
class StreamOptions(OpenAIBaseModel):
include_usage: Optional[bool] = True
continuous_usage_stats: Optional[bool] = False
@@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
max_completion_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
response_format: Optional[AnyResponseFormat] = None
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
@@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, the output will follow the context free grammar."),
)
structural_tag: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the structural tag schema."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
@@ -476,6 +500,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema = self.response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
elif self.response_format.type == "structural_tag":
structural_tag = self.response_format
assert structural_tag is not None and isinstance(
structural_tag, StructuralTagResponseFormat)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structural_tag = json.dumps(s_tag_obj)
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
@@ -485,6 +515,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern,
structural_tag=self.structural_tag,
)
return SamplingParams.from_optional(
@@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel):
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
response_format: Optional[ResponseFormat] = Field(
response_format: Optional[AnyResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
"{'type': 'text' } is supported."),
description=(
"Similar to chat completion, this parameter specifies the format "
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,