[Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
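Context for the title: the PR folds the parallel `prompt` / `prompt_token_ids` arguments into a single prompt-inputs argument on the LLM engines, and the test-file diff below mostly adapts markers and server fixtures to match. A minimal sketch of the consolidated call shape, assuming the dict-style prompt inputs this PR introduces (model name and token ids are illustrative):

from vllm import LLM

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")  # illustrative model

# One argument now carries either a text prompt or pre-tokenized ids,
# instead of separate `prompt=` and `prompt_token_ids=` keyword lists.
outputs = llm.generate([
    {"prompt": "Hello, my name is"},
    {"prompt_token_ids": [1, 15043, 29892]},
])
for out in outputs:
    print(out.outputs[0].text)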
@@ -71,7 +71,7 @@ TEST_CHOICE = [
     "Swift", "Kotlin"
 ]
 
-pytestmark = pytest.mark.asyncio
+pytestmark = pytest.mark.openai
 
 
 @pytest.fixture(scope="session")
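The marker swap above assumes an `openai` marker is registered so these tests can be selected with `pytest -m openai`; a plausible registration sketch (location assumed, the repo may declare it in setup.cfg or pyproject.toml instead):

# conftest.py (assumed): register the custom marker so pytest does not
# warn about an unknown mark and `pytest -m openai` can select the suite.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "openai: tests for the OpenAI-compatible API server")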
@@ -91,6 +91,8 @@ def server(zephyr_lora_files):
         "--max-model-len",
         "8192",
         "--enforce-eager",
+        "--gpu-memory-utilization",
+        "0.75",
         # lora config below
         "--enable-lora",
         "--lora-modules",
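The two flags added above cap the fraction of GPU memory the engine pre-allocates, leaving headroom when several CI jobs share a device. A minimal sketch of the same knobs through the offline entry point (model name illustrative):

from vllm import LLM

# Equivalent of "--gpu-memory-utilization 0.75 --enforce-eager" on the
# API server: pre-allocate at most 75% of GPU memory for weights and KV
# cache, and run eagerly instead of capturing CUDA graphs.
llm = LLM(model="HuggingFaceH4/zephyr-7b-beta",
          gpu_memory_utilization=0.75,
          enforce_eager=True)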
@@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files):
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
+        "--enforce-eager",
+        "--gpu-memory-utilization",
+        "0.75",
         "--max-model-len",
         "8192",
-        "--enforce-eager",
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
@@ -136,6 +140,7 @@ def client():
     yield client
 
 
+@pytest.mark.asyncio
 async def test_check_models(server, client: openai.AsyncOpenAI):
     models = await client.models.list()
     models = models.data
@@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):
     assert lora_models[1].id == "zephyr-lora2"
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
     "model_name",
@@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
         completion.choices[0].text) >= 5
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
     "model_name",
@@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs.top_logprobs is None
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",
@@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
                                  model_name: str):
@@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",
@@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",
@@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
     assert "".join(chunks) == output
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",
@@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
     assert texts[0] == texts[1]
 
 
+@pytest.mark.asyncio
 async def test_logits_bias(server, client: openai.AsyncOpenAI):
     prompt = "Hello, my name is"
     max_tokens = 5
@@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
     assert first_response != completion.choices[0].text
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
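For reference, the guided-decoding tests in the hunks below share one request shape; a condensed sketch using the module's own constants and imports (the prompt string and test name here are illustrative):

@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_json_sketch(server, client: openai.AsyncOpenAI,
                                  guided_decoding_backend: str):
    # guided_json and guided_decoding_backend are vLLM extensions passed
    # through the OpenAI client's extra_body escape hatch.
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="Give an example JSON object describing a person:",
        max_tokens=500,
        extra_body=dict(guided_json=TEST_SCHEMA,
                        guided_decoding_backend=guided_decoding_backend))
    # The output text should parse and validate against the shared schema.
    jsonschema.validate(instance=json.loads(completion.choices[0].text),
                        schema=TEST_SCHEMA)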
@@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
         jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
@@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
     assert json1["age"] != json2["age"]
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
@@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
         assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
@@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
     assert ip1 != ip2
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
@@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
         assert completion.choices[i].text in TEST_CHOICE
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
@@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
     assert choice1 != choice2
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
@@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
             extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
@@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
                for token, logprob in token_dict.items())
 
 
+@pytest.mark.asyncio
 async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
     for _ in range(2):
         resp = await client.chat.completions.create(
@@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
         assert loaded == {"result": 2}, loaded
 
 
+@pytest.mark.asyncio
 async def test_extra_fields(server, client: openai.AsyncOpenAI):
     with pytest.raises(BadRequestError) as exc_info:
         await client.chat.completions.create(
@@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
     assert "extra_forbidden" in exc_info.value.message
 
 
+@pytest.mark.asyncio
 async def test_complex_message_content(server, client: openai.AsyncOpenAI):
     resp = await client.chat.completions.create(
         model=MODEL_NAME,
@@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
     assert content == "2"
 
 
+@pytest.mark.asyncio
 async def test_custom_role(server, client: openai.AsyncOpenAI):
     # Not sure how the model handles custom roles so we just check that
     # both string and complex message content are handled in the same way
@@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI):
     assert content1 == content2
 
 
+@pytest.mark.asyncio
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement
@@ -847,6 +873,7 @@ number: "1" | "2"
     assert content.strip() == ground_truth
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
     "model_name",
@@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
     assert len(logprobs.tokens) > 5
 
 
+@pytest.mark.asyncio
 async def test_long_seed(server, client: openai.AsyncOpenAI):
     for seed in [
             torch.iinfo(torch.long).min - 1,
@@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
                 or "less_than_equal" in exc_info.value.message)
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
     [EMBEDDING_MODEL_NAME],
@@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
     assert embeddings.usage.total_tokens == 5
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
     [EMBEDDING_MODEL_NAME],
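The embedding hunks follow the same marker pattern as the rest of the file; a condensed sketch of the request they exercise (input text and test name illustrative):

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [EMBEDDING_MODEL_NAME])
async def test_embedding_sketch(embedding_server, client: openai.AsyncOpenAI,
                                model_name: str):
    # Single-input embedding request against the fixture server.
    embeddings = await client.embeddings.create(model=model_name,
                                                input="Hello my name is")
    assert len(embeddings.data) == 1
    assert embeddings.usage.total_tokens > 0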