Test Prompt Embeds/LoRA compatibility and Enable LoRA Support for OPT Models (#25717)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
@@ -16,13 +17,15 @@ from ...utils import RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
LORA_SERVING_MODEL_NAME = "opt125m-lora"
|
||||
|
||||
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args() -> list[str]:
|
||||
return [
|
||||
@pytest.fixture(scope="module", params=["use-lora"])
|
||||
def default_server_args(request: pytest.FixtureRequest,
|
||||
opt125_lora_files: str) -> list[str]:
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
@@ -35,6 +38,25 @@ def default_server_args() -> list[str]:
|
||||
"--enable-prompt-embeds",
|
||||
]
|
||||
|
||||
if request.param == "use-lora":
|
||||
lora_module_1 = {
|
||||
"name": LORA_SERVING_MODEL_NAME,
|
||||
"path": opt125_lora_files,
|
||||
"base_model_name": MODEL_NAME
|
||||
}
|
||||
|
||||
args.extend([
|
||||
"--enable-lora",
|
||||
"--lora-module",
|
||||
json.dumps(lora_module_1),
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
])
|
||||
|
||||
return args
|
||||
|
||||
|
||||
EXAMPLE_PROMPTS = [
|
||||
"Hello, my name is",
|
||||
@@ -74,7 +96,7 @@ async def client_with_prompt_embeds(server_with_prompt_embeds):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
|
||||
async def test_completions_with_prompt_embeds(
|
||||
example_prompt_embeds,
|
||||
client_with_prompt_embeds: openai.AsyncOpenAI,
|
||||
@@ -179,7 +201,7 @@ async def test_completions_with_prompt_embeds(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
|
||||
async def test_completions_errors_with_prompt_embeds(
|
||||
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
|
||||
# Test error case: invalid prompt_embeds
|
||||
@@ -194,7 +216,7 @@ async def test_completions_errors_with_prompt_embeds(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("logprobs_arg", [1, 0])
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
|
||||
async def test_completions_with_logprobs_and_prompt_embeds(
|
||||
example_prompt_embeds,
|
||||
client_with_prompt_embeds: openai.AsyncOpenAI,
|
||||
|
||||
Reference in New Issue
Block a user