[Bugfix] Fix OpenAI parallel sampling when using xgrammar (#11637)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
@@ -28,6 +28,8 @@ PA_NAME = "swapnilbp/llama_tweet_ptune"
|
||||
# need to change to match the prompt adapter
|
||||
PA_NUM_VIRTUAL_TOKENS = 8
|
||||
|
||||
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def zephyr_lora_files():
|
||||
@@ -635,8 +637,7 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_json_schema):
|
||||
@@ -658,8 +659,7 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_regex):
|
||||
@@ -680,8 +680,7 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_guided_choice):
|
||||
@@ -761,8 +760,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_json_schema, sample_regex):
|
||||
|
||||
Reference in New Issue
Block a user