Re-enable the 80 char line width limit (#3305)
@@ -5,9 +5,12 @@ import time
 import sys
 import pytest
 import requests
-import ray  # using Ray for overall ease of process management, parallel requests, and debugging.
+# using Ray for overall ease of process management, parallel requests,
+# and debugging.
+import ray
 import openai  # use the official client for correctness check
-from huggingface_hub import snapshot_download  # downloading lora to test lora requests
+# downloading lora to test lora requests
+from huggingface_hub import snapshot_download
 
 # imports for guided decoding tests
 import json
@@ -17,8 +20,11 @@ import re
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds
-MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
-LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+# technically this needs Mistral-7B-v0.1 as base, but we're not testing
+# generation quality here
+LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
 TEST_SCHEMA = {
     "type": "object",
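The MAX_SERVER_START_WAIT_S constant caps how long the test fixture polls the freshly launched server before giving up. A minimal sketch of that kind of startup wait, assuming a /health endpoint on localhost:8000; the helper name and URL here are illustrative, not the file's exact code:

import time

import requests

MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds


def wait_for_server(url: str = "http://localhost:8000/health") -> None:
    # Poll the health endpoint until it answers 200 or the deadline passes.
    start = time.time()
    while True:
        try:
            if requests.get(url).status_code == 200:
                return
        except requests.ConnectionError:
            pass  # server is not accepting connections yet
        if time.time() - start > MAX_SERVER_START_WAIT_S:
            raise RuntimeError("Server failed to start in time")
        time.sleep(0.5)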
@@ -59,8 +65,8 @@ TEST_SCHEMA = {
     "required": ["name", "age", "skills", "work history"]
 }
 
-TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
-    r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
+TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
+              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
 
 TEST_CHOICE = [
     "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
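The reworked TEST_REGEX relies on Python's implicit concatenation of adjacent string literals inside parentheses, which produces exactly the same pattern string as the old backslash-continuation form. A quick self-contained check; the sample addresses are illustrative:

import re

# Adjacent raw-string literals inside parentheses are concatenated at
# compile time, so this is byte-for-byte the same pattern as before.
TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

assert re.fullmatch(TEST_REGEX, "192.168.0.1") is not None
assert re.fullmatch(TEST_REGEX, "999.168.0.1") is None  # 999 is no octet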
@@ -120,8 +126,9 @@ def server(zephyr_lora_files):
     server_runner = ServerRunner.remote([
         "--model",
         MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",  # use half precision for speed and memory savings in CI environment
+        "bfloat16",
         "--max-model-len",
         "8192",
         "--enforce-eager",
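For context, this argument list is what the ServerRunner actor forwards to vLLM's OpenAI-compatible entrypoint. A rough standalone equivalent, assuming the vllm.entrypoints.openai.api_server module path in use at the time (a sketch, not the fixture's actual launch code):

import subprocess
import sys

# Launch the same server the test fixture starts, outside of Ray.
proc = subprocess.Popen([
    sys.executable, "-m", "vllm.entrypoints.openai.api_server",
    "--model", "HuggingFaceH4/zephyr-7b-beta",
    # use half precision for speed and memory savings in CI environment
    "--dtype", "bfloat16",
    "--max-model-len", "8192",
    "--enforce-eager",
])
# ... run requests against it, then shut it down:
proc.terminate()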
@@ -392,7 +399,8 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
         max_tokens=5,
         temperature=0.0,
         extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
+            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
+            # for official client.
             use_beam_search=True),
     )
     assert len(batch.choices) == 4
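The rewrapped NOTE encodes a real constraint: with vLLM behind the OpenAI client, n > 1 requires use_beam_search to be enabled through extra_body, an extension field the official API would ignore. A minimal sketch, assuming a server already running on localhost:8000 (base_url and api_key are placeholders):

import asyncio

import openai


async def main() -> None:
    # vLLM's server does not check the API key unless one is configured.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    batch = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt="Hello, my name is",
        n=2,
        max_tokens=5,
        temperature=0.0,
        # required by vLLM for n > 1; not necessary for the official client
        extra_body=dict(use_beam_search=True),
    )
    assert len(batch.choices) == 2


asyncio.run(main())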
@@ -469,8 +477,8 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
 async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
     completion = await client.completions.create(
         model=MODEL_NAME,
-        prompt=
-        f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
+        prompt=f"Give an example JSON for an employee profile "
+        f"that fits this schema: {TEST_SCHEMA}",
         n=3,
         temperature=1.0,
         max_tokens=500,
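The prompt fix uses the same implicit-concatenation trick with f-strings: two adjacent f-string literals join into a single string, so the long prompt fits within 80 characters without "+" or backslash continuations. For illustration, with a tiny stand-in for TEST_SCHEMA:

schema = {"type": "object"}  # stand-in for the file's TEST_SCHEMA

# Adjacent f-string literals concatenate just like plain literals do.
prompt = (f"Give an example JSON for an employee profile "
          f"that fits this schema: {schema}")

assert prompt.startswith("Give an example JSON")
assert str(schema) in prompt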
@@ -489,9 +497,11 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
         "role": "system",
         "content": "you are a helpful assistant"
     }, {
-        "role": "user",
-        "content": "Give an example JSON for an employee profile that " + \
-            f"fits this schema: {TEST_SCHEMA}"
+        "role":
+        "user",
+        "content":
+        f"Give an example JSON for an employee profile that "
+        f"fits this schema: {TEST_SCHEMA}"
     }]
     chat_completion = await client.chat.completions.create(
         model=MODEL_NAME,
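For completeness, the guided-JSON tests pass the schema to the server through vLLM's guided_json extension field. A minimal sketch of such a request, assuming the same local server as above; the schema here is abbreviated and the base_url/api_key are placeholders:

import asyncio
import json

import openai

TEST_SCHEMA = {"type": "object"}  # abbreviated; the file's schema is larger


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": f"Give an example JSON for an employee profile that "
                   f"fits this schema: {TEST_SCHEMA}"
    }]
    chat_completion = await client.chat.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        messages=messages,
        max_tokens=500,
        # guided_json is a vLLM extension constraining output to the schema
        extra_body=dict(guided_json=TEST_SCHEMA))
    content = chat_completion.choices[0].message.content
    assert content is not None
    json.loads(content)  # the reply should parse as schema-conforming JSON


asyncio.run(main())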