[Misc][LoRA] Improve the readability of LoRA error messages (#12102)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-01-17 19:32:28 +08:00
committed by GitHub
parent 69d765f5a5
commit 07934cc237
10 changed files with 243 additions and 114 deletions

View File

@@ -17,6 +17,33 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
BADREQUEST_CASES = [
(
"test_rank",
{
"r": 1024
},
"is greater than max_lora_rank",
),
(
"test_bias",
{
"bias": "all"
},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {
"use_dora": True
}, "does not yet support DoRA"),
(
"test_modules_to_save",
{
"modules_to_save": ["lm_head"]
},
"only supports modules_to_save being None",
),
]
@pytest.fixture(scope="module")
def zephyr_lora_files():
@@ -138,32 +165,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
tmp_path, zephyr_lora_files):
invalid_rank = tmp_path / "invalid_rank"
@pytest.mark.parametrize("test_name,config_change,expected_error",
BADREQUEST_CASES)
async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
zephyr_lora_files, test_name: str,
config_change: dict,
expected_error: str):
# Create test directory
test_dir = tmp_path / test_name
# Copy adapter from zephyr_lora_files to invalid_rank
shutil.copytree(zephyr_lora_files, invalid_rank)
# Copy adapter files
shutil.copytree(zephyr_lora_files, test_dir)
with open(invalid_rank / "adapter_config.json") as f:
# Load and modify configuration
config_path = test_dir / "adapter_config.json"
with open(config_path) as f:
adapter_config = json.load(f)
# Apply configuration changes
adapter_config.update(config_change)
print(adapter_config)
# assert False
# Change rank to invalid value
adapter_config["r"] = 1024
with open(invalid_rank / "adapter_config.json", "w") as f:
# Save modified configuration
with open(config_path, "w") as f:
json.dump(adapter_config, f)
with pytest.raises(openai.BadRequestError,
match="is greater than max_lora_rank"):
# Test loading the adapter
with pytest.raises(openai.BadRequestError, match=expected_error):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_rank)
"lora_name": test_name,
"lora_path": str(test_dir)
})