[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Nick Hill
2024-07-18 00:13:30 -07:00
committed by GitHub
parent 8a74c68bd1
commit e2fbaee725
16 changed files with 267 additions and 186 deletions

View File

@@ -5,13 +5,15 @@ import requests
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def server():
def server(zephyr_lora_added_tokens_files: str): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -21,12 +23,25 @@ def server():
"--enforce-eager",
"--max-num-seqs",
"128",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def tokenizer_name(model_name: str,
zephyr_lora_added_tokens_files: str): # noqa: F811
return zephyr_lora_added_tokens_files if (
model_name == "zephyr-lora2") else model_name
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
@@ -34,16 +49,18 @@ def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_completions(client: openai.AsyncOpenAI,
model_name: str):
model_name: str, tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
for add_special in [False, True]:
prompt = "This is a test prompt."
prompt = "vllm1 This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
response = requests.post(base_url + "/tokenize",
@@ -63,12 +80,15 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
for add_generation in [False, True]:
for add_special in [False, True]:
@@ -80,7 +100,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
"content": "Nice to meet you!"
}, {
"role": "user",
"content": "Can I ask a question?"
"content": "Can I ask a question? vllm1"
}]
prompt = tokenizer.apply_chat_template(
@@ -108,16 +128,20 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
prompt = "This is a test prompt."
prompt = "This is a test prompt. vllm1"
tokens = tokenizer.encode(prompt, add_special_tokens=False)
print(f"CALLING {base_url} FOR {model_name}")
response = requests.post(base_url + "/detokenize",
json={
"model": model_name,