Files
vllm/tests/lora/test_qwen35_densemoel_lora.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

133 lines
4.7 KiB
Python
Raw Normal View History

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer
import vllm
import vllm.config
from vllm.lora.request import LoRARequest
from ..utils import create_new_process_for_each_test, multi_gpu_test
# Model identifier used both to load the tokenizer below and to construct the
# vllm.LLM engines in the tests.
MODEL_PATH = "Qwen/Qwen3.5-4B"
# Text-to-SQL prompt: a fixed schema description followed by a `{query}`
# placeholder that do_sample() fills in via str.format().
PROMPT_TEMPLATE = """Write a SQL query for the given database.\nSchema:\nTables:\n - stadium(Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average)\n - singer(Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male)\n - concert(concert_ID, concert_Name, Theme, Stadium_ID, Year)\n - singer_in_concert(concert_ID, Singer_ID)\n\nQuestion:\n{query}""" # noqa: E501
# Expected generations, compared index-by-index against the list returned by
# do_sample(); the order follows the order of the questions built there.
EXPECTED_LORA_OUTPUT = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",
"SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)",
]
# Loaded once at import time; do_sample() uses it to render each user message
# through the model's chat template before generation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Generate SQL answers for the three fixed benchmark questions.

    Each question is inserted into ``PROMPT_TEMPLATE``, wrapped as a single
    user message and rendered with the module-level tokenizer's chat template
    (``enable_thinking=False``). The rendered prompts are generated with
    ``temperature=0`` and ``max_tokens=512``.

    Args:
        llm: Engine used for generation.
        lora_path: Location of the LoRA adapter, forwarded to ``LoRARequest``.
        lora_id: Adapter id, also used (stringified) as the adapter name.
            A falsy value (e.g. ``0``) sends the request without a LoRA.

    Returns:
        The stripped text of the first completion of every output, in the
        order returned by ``llm.generate``.
    """
    questions = (
        "How many singers do we have?",
        "What is the average, minimum, and maximum age of all singers from France?",
        "What are the names of the stadiums without any concerts?",
    )

    # Render every question as a one-turn chat so the model sees its own
    # chat formatting, with the generation prompt appended.
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": PROMPT_TEMPLATE.format(query=question)}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # disable thinking
        )
        for question in questions
    ]

    greedy = vllm.SamplingParams(temperature=0, max_tokens=512)
    if lora_id:
        adapter = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        adapter = None
    results = llm.generate(chat_prompts, greedy, lora_request=adapter)

    completions: list[str] = []
    for result in results:
        shown_prompt = result.prompt
        text = result.outputs[0].text.strip()
        completions.append(text)
        print(f"Prompt: {shown_prompt!r}, Generated text: {text!r}")
    return completions
@create_new_process_for_each_test()
def test_qwen35_dense_model_lora(qwen35_dense_model_lora_files):
    """Check LoRA generations on an engine built without tensor parallelism.

    The same adapter files are requested under LoRA ids 1 and 2; each run
    must reproduce ``EXPECTED_LORA_OUTPUT`` entry by entry.

    Args:
        qwen35_dense_model_lora_files: Passed to ``do_sample`` as the LoRA
            path (presumably a pytest fixture defined outside this file).
    """
    engine_kwargs = dict(
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_num_seqs=16,
        max_lora_rank=8,
        trust_remote_code=True,
    )
    llm = vllm.LLM(MODEL_PATH, **engine_kwargs)

    for lora_id in (1, 2):
        generated = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert generated[idx] == expected
@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4(qwen35_dense_model_lora_files):
    """Check LoRA generations with ``tensor_parallel_size=4``.

    The engine is built with ``fully_sharded_loras=False``. The same adapter
    files are requested under LoRA ids 1 and 2; each run must reproduce
    ``EXPECTED_LORA_OUTPUT`` entry by entry.

    Args:
        qwen35_dense_model_lora_files: Passed to ``do_sample`` as the LoRA
            path (presumably a pytest fixture defined outside this file).
    """
    compile_cfg = vllm.config.CompilationConfig(  # Avoid OOM
        cudagraph_specialize_lora=False,
    )
    engine_kwargs = dict(
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        max_num_seqs=16,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=False,
        compilation_config=compile_cfg,
    )
    llm = vllm.LLM(MODEL_PATH, **engine_kwargs)

    first_run = do_sample(llm, qwen35_dense_model_lora_files, lora_id=1)
    print(first_run)
    for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
        assert first_run[idx] == expected

    second_run = do_sample(llm, qwen35_dense_model_lora_files, lora_id=2)
    for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
        assert second_run[idx] == expected
@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4_fully_sharded_loras(qwen35_dense_model_lora_files):
    """Check LoRA generations with ``tensor_parallel_size=4`` and sharded LoRAs.

    The engine is built with ``fully_sharded_loras=True`` and
    ``gpu_memory_utilization=0.8``. The same adapter files are requested
    under LoRA ids 1 and 2; each run must reproduce ``EXPECTED_LORA_OUTPUT``
    entry by entry.

    Args:
        qwen35_dense_model_lora_files: Passed to ``do_sample`` as the LoRA
            path (presumably a pytest fixture defined outside this file).
    """
    compile_cfg = vllm.config.CompilationConfig(  # Avoid OOM
        cudagraph_specialize_lora=False,
    )
    engine_kwargs = dict(
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=True,
        gpu_memory_utilization=0.8,
        compilation_config=compile_cfg,
    )
    llm = vllm.LLM(MODEL_PATH, **engine_kwargs)

    for lora_id in (1, 2):
        generated = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert generated[idx] == expected