# SPDX-License-Identifier: Apache-2.0
import pytest
import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "ibm-granite/granite-3b-code-base"
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
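    """Run greedy decoding on the fixed SQL prompts.

    A lora_id of 0 runs the base model (no LoRARequest is attached);
    any other id loads the adapter stored at lora_path.
    """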
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\ncontext: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\nquestion: which gender got the highest average uncertain ratio. [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\ncontext: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\nquestion: Name the women's doubles for werner schlager [/user] [assistant]"  # noqa: E501
    ]
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=256,
                                          stop=["[/assistant]"])
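    # lora_id == 0 means "no adapter": pass lora_request=None so the base
    # model is used for the comparison run.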
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    generated_texts: list[str] = []
    for output in outputs:
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
    return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


# Skipping for V1 for now as we are hitting a
# "Head size 80 is not supported by FlashAttention." error.
@pytest.mark.skip_v1
@pytest.mark.parametrize ( " lora_bias " , [ True ] )
@pytest.mark.parametrize ( " fully_sharded " , [ True , False ] )
def test_lora_bias ( lora_bias_files : str , lora_bias : bool , fully_sharded : bool ) :
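    """End-to-end check that LoRA bias terms actually affect generation.

    The adapter under lora_bias_files presumably carries bias deltas, so a
    run through the adapter should diverge from the base-model run whenever
    enable_lora_bias is switched on.
    """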
    llm = vllm.LLM(MODEL_PATH,
                   enable_lora=True,
                   max_num_seqs=16,
                   max_lora_rank=8,
                   max_loras=1,
                   enable_lora_bias=lora_bias,
                   tensor_parallel_size=1,
                   fully_sharded_loras=fully_sharded)
print ( " lora adapter created " )
output1 = do_sample ( llm , lora_bias_files , lora_id = 0 )
print ( " lora " )
output2 = do_sample ( llm , lora_bias_files , lora_id = 1 )
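    # With bias support enabled, the adapter's bias terms should change the
    # generated text relative to the base-model run; with it disabled they
    # are presumably ignored, so the two runs should match.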
    if lora_bias:
        assert output1 != output2
    else:
        assert output1 == output2