2025-02-02 14:58:18 -05:00
# SPDX-License-Identifier: Apache-2.0
2025-06-03 11:20:17 -07:00
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2025-05-22 21:44:18 -04:00
import subprocess
import sys
2025-02-02 14:58:18 -05:00
2025-10-20 05:21:09 +01:00
import pytest
2024-11-24 09:23:17 +08:00
import vllm
2025-10-20 05:21:09 +01:00
import vllm . config
2025-05-22 21:44:18 -04:00
from vllm import LLM
2024-11-24 09:23:17 +08:00
from vllm . lora . request import LoRARequest
2025-05-22 21:44:18 -04:00
from vllm . model_executor . model_loader . tensorizer import TensorizerConfig
2024-11-24 09:23:17 +08:00
2025-05-22 21:44:18 -04:00
from . . utils import VLLM_PATH , create_new_process_for_each_test , multi_gpu_test
2024-11-24 09:23:17 +08:00
# Model identifier passed to the engine constructor by the tests in this file.
# NOTE(review): every string literal below keeps the leading/trailing spaces
# exactly as found in the source; that padding looks like a formatting
# artifact rather than intent -- confirm before relying on it.
MODEL_PATH = " meta-llama/Llama-2-7b-hf "

# Reference completions. The tests compare the list returned by do_sample()
# against this list with ``==``, so both the order and the exact text matter.
EXPECTED_LORA_OUTPUT = [
    " SELECT icao FROM table_name_74 WHERE airport = ' lilongwe international airport ' ",  # noqa: E501
    " SELECT nationality FROM table_name_11 WHERE elector = ' anchero pantaleone ' ",
    " SELECT one_mora FROM table_name_95 WHERE gloss = ' low tone mora with a gloss of /˩okiru/ ' [òkìɽɯ ́] AND accented_mora = ' low tone mora with a gloss of /˩okiru/ ' [òkìɽɯ ́] ",  # noqa: E501
    " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",  # noqa: E501
    " SELECT pick FROM table_name_60 WHERE former_wnba_team = ' Minnesota Lynx ' ",
    " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = ' Werner Schlager ' ",  # noqa: E501
]
2025-05-22 21:44:18 -04:00
def do_sample(
    llm: vllm.LLM,
    lora_path: str,
    lora_id: int,
    tensorizer_config_dict: dict | None = None,
) -> list[str]:
    """Run the fixed SQL prompts through ``llm`` and return the generated texts.

    Args:
        llm: Engine whose ``generate`` method is called once with all prompts.
        lora_path: Adapter location handed to ``LoRARequest``.
        lora_id: Adapter id; its string form is used as the request name. A
            falsy id (such as ``0``) means no LoRA request is sent.
        tensorizer_config_dict: When not ``None``, forwarded to
            ``LoRARequest`` under the keyword of the same name.

    Returns:
        The text of the first output of each result, in the order the results
        are returned by ``llm.generate``.
    """
    # NOTE(review): the string literals keep the leading/trailing spaces
    # exactly as found in the source (prompts, stop string, log message);
    # that padding looks like a formatting artifact -- confirm.
    prompts = [
        " [user] Write a SQL query to answer the question based on the table schema. \n \n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR) \n \n question: Name the ICAO for lilongwe international airport [/user] [assistant] ",  # noqa: E501
        " [user] Write a SQL query to answer the question based on the table schema. \n \n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR) \n \n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant] ",  # noqa: E501
        " [user] Write a SQL query to answer the question based on the table schema. \n \n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR) \n \n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ ́]? [/user] [assistant] ",  # noqa: E501
        " [user] Write a SQL query to answer the question based on the table schema. \n \n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR) \n \n question: which gender got the highest average uncertain ratio. [/user] [assistant] ",  # noqa: E501
        " [user] Write a SQL query to answer the question based on the table schema. \n \n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR) \n \n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant] ",  # noqa: E501
        " [user] Write a SQL query to answer the question based on the table schema. \n \n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR) \n \n question: Name the women ' s doubles for werner schlager [/user] [assistant] ",  # noqa: E501
    ]

    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=256,
        skip_special_tokens=False,
        stop=[" [/assistant] "],
    )

    # Build the adapter request once. The tensorizer keyword is only supplied
    # when a config dict was actually given.
    if not lora_id:
        lora_request = None
    elif tensorizer_config_dict is None:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        lora_request = LoRARequest(
            str(lora_id),
            lora_id,
            lora_path,
            tensorizer_config_dict=tensorizer_config_dict,
        )

    results = llm.generate(prompts, sampling_params, lora_request=lora_request)

    # Collect the first candidate of every result and echo it for the test log.
    generated_texts: list[str] = []
    for result in results:
        prompt = result.prompt
        generated_text = result.outputs[0].text
        generated_texts.append(generated_text)
        print(f" Prompt: {prompt!r} , Generated text: {generated_text!r} ")
    return generated_texts
2025-05-22 21:44:18 -04:00
def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = None):
    """Assert that adapter ids 1 and 2 both reproduce ``EXPECTED_LORA_OUTPUT``.

    ``do_sample`` is called twice with the same adapter files, first with
    ``lora_id=1`` and then with ``lora_id=2``; each returned list must equal
    ``EXPECTED_LORA_OUTPUT``.

    Args:
        llm: Engine forwarded to ``do_sample``.
        sql_lora_files: Adapter path forwarded to ``do_sample`` as its
            ``lora_path`` argument.
        tensorizer_config_dict: Optional config forwarded unchanged to
            ``do_sample``.

    Raises:
        AssertionError: If either run differs from ``EXPECTED_LORA_OUTPUT``.
    """
    print(" lora adapter created ")

    print(" lora 1 ")
    first_run = do_sample(
        llm,
        sql_lora_files,
        lora_id=1,
        tensorizer_config_dict=tensorizer_config_dict,
    )
    assert first_run == EXPECTED_LORA_OUTPUT

    print(" lora 2 ")
    second_run = do_sample(
        llm,
        sql_lora_files,
        lora_id=2,
        tensorizer_config_dict=tensorizer_config_dict,
    )
    assert second_run == EXPECTED_LORA_OUTPUT

    # Only a log line: nothing is removed inside this function.
    print(" removing lora ")
2025-03-17 19:33:35 +08:00
@create_new_process_for_each_test()
@pytest.mark.parametrize(" cudagraph_specialize_lora ", [True, False])
def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
    """Check LoRA generation with ``cudagraph_specialize_lora`` on and off.

    Args:
        sql_lora_files: Fixture value used both as the tokenizer argument and
            as the adapter path.
        cudagraph_specialize_lora: Parametrized flag placed into the
            ``CompilationConfig`` handed to the engine.
    """
    compilation_config = vllm.config.CompilationConfig(
        cudagraph_specialize_lora=cudagraph_specialize_lora,
    )
    llm = vllm.LLM(
        MODEL_PATH,
        tokenizer=sql_lora_files,
        enable_lora=True,
        max_num_seqs=13,  # deliberately odd: also test odd max_num_seqs
        max_loras=4,
        compilation_config=compilation_config,
    )
    generate_and_test(llm, sql_lora_files)
2024-11-24 09:23:17 +08:00
@multi_gpu_test(num_gpus=4)
def test_llama_lora_tp4(sql_lora_files):
    """Check LoRA generation with ``tensor_parallel_size=4``.

    Args:
        sql_lora_files: Fixture value used both as the tokenizer argument and
            as the adapter path.
    """
    engine_options = dict(
        tokenizer=sql_lora_files,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
    )
    llm = vllm.LLM(MODEL_PATH, **engine_options)
    generate_and_test(llm, sql_lora_files)
2024-11-24 09:23:17 +08:00
@multi_gpu_test(num_gpus=4)
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
    """Check LoRA generation with ``tensor_parallel_size=4`` and
    ``fully_sharded_loras=True``.

    Args:
        sql_lora_files: Fixture value used both as the tokenizer argument and
            as the adapter path.
    """
    tokenizer_source = sql_lora_files
    adapter_files = sql_lora_files

    llm = vllm.LLM(
        MODEL_PATH,
        tokenizer=tokenizer_source,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
        fully_sharded_loras=True,
    )
    generate_and_test(llm, adapter_files)
2025-05-22 21:44:18 -04:00
@multi_gpu_test(num_gpus=2)
def test_tp2_serialize_and_deserialize_lora(
    tmp_path, sql_lora_files, sql_lora_huggingface_id
):
    """Serialize the model and LoRA adapter, reload them, and check generation.

    The serialization step runs ``tensorize_vllm_model.py`` as a child Python
    process; the engine is then created with the tensorizer load format and
    ``do_sample`` must reproduce ``EXPECTED_LORA_OUTPUT`` for ``lora_id=1``.

    Args:
        tmp_path: Directory that receives the serialized artifacts.
        sql_lora_files: Fixture value used both as the tokenizer argument and
            as the adapter path for sampling.
        sql_lora_huggingface_id: Value passed to the script's ``--lora-path``
            option.

    Raises:
        subprocess.CalledProcessError: Re-raised after printing the child's
            stdout and stderr when the serialization script exits non-zero.
    """
    # NOTE(review): the string literals keep the leading/trailing spaces
    # exactly as found in the source (CLI flags, path pieces, JSON kwargs);
    # that padding looks like a formatting artifact -- confirm.
    tp_size = 2
    # Presumably a per-rank file name template ("%03d") -- confirm.
    model_name = " model-rank- %03d .tensors "
    model_ref = MODEL_PATH
    lora_path = sql_lora_huggingface_id
    suffix = " test "

    # Run the tensorizing of the LoRA adapter and the model in a subprocess
    # to guarantee cleanup.
    serialize_command = [
        sys.executable,
        f" {VLLM_PATH} /examples/others/tensorize_vllm_model.py ",
        " --model ",
        MODEL_PATH,
        " --lora-path ",
        lora_path,
        " --tensor-parallel-size ",
        str(tp_size),
        " serialize ",
        " --serialized-directory ",
        str(tmp_path),
        " --suffix ",
        suffix,
        " --serialization-kwargs ",
        ' { " limit_cpu_concurrency " : 4} ',
    ]
    try:
        result = subprocess.run(
            serialize_command,
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        # Surface the child's output before letting the failure propagate.
        print(" Tensorizing failed. ")
        print(" STDOUT: \n ", e.stdout)
        print(" STDERR: \n ", e.stderr)
        raise

    print(" STDOUT: \n ", result.stdout)

    # Point the loader at the files written below tmp_path.
    model_uri = tmp_path / " vllm " / model_ref / suffix / model_name
    tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))

    loaded_llm = LLM(
        model=model_ref,
        tokenizer=sql_lora_files,
        load_format=" tensorizer ",
        enable_lora=True,
        enforce_eager=True,
        model_loader_extra_config=tensorizer_config,
        max_num_seqs=13,
        tensor_parallel_size=tp_size,
        max_loras=2,
    )

    tc_as_dict = tensorizer_config.to_serializable()

    print(" lora adapter created ")
    print(" lora 1 ")
    generated = do_sample(
        loaded_llm,
        sql_lora_files,
        tensorizer_config_dict=tc_as_dict,
        lora_id=1,
    )
    assert generated == EXPECTED_LORA_OUTPUT