[Bugfix] Fix OpenAI parallel sampling when using xgrammar (#11637)

Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Michael Goin
2024-12-30 22:43:54 -05:00
committed by GitHub
parent a2a40bcd0d
commit 74fa1d123c
4 changed files with 17 additions and 13 deletions

View File

@@ -28,6 +28,8 @@ PA_NAME = "swapnilbp/llama_tweet_ptune"
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
@pytest.fixture(scope="module")
def zephyr_lora_files():
@@ -635,8 +637,7 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_json_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema):
@@ -658,8 +659,7 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_regex):
@@ -680,8 +680,7 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_guided_choice):
@@ -761,8 +760,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema, sample_regex):