[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend (#14238)
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
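For context, a minimal sketch of what this change enables: serving LoRA adapters through the regular vllm.LLM entry point on TPU. The model and adapter names below are the ones exercised by the new tests; the remaining arguments are illustrative assumptions taken from those tests, not required settings.

# Hedged usage sketch for multi-LoRA on the V1 TPU backend. Values are
# copied from the tests in this commit and are illustrative only.
import os
os.environ["VLLM_USE_V1"] = "1"  # Multi-LoRA on TPU is V1-only

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
               enable_lora=True,
               max_loras=4,      # number of adapters held at once
               max_lora_rank=8)

req = LoRARequest(
    "lora_adapter_1", 1,
    "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
out = llm.generate("What is 1+1? \n",
                   sampling_params=vllm.SamplingParams(max_tokens=16,
                                                       temperature=0),
                   lora_request=req)
print(out[0].outputs[0].text)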
tests/tpu/lora/__init__.py   (new file, 0 lines)
tests/tpu/lora/test_lora.py  (new file, 124 lines)
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom-trained adapters for Qwen2.5-3B-Instruct:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard Hugging Face PEFT training
# script, where all the inputs are "What is 1+1? \n" and all the outputs are
# "x". We run 100 training iterations with a training batch size of 100.


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """
    Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
    for all tests in this file.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        yield


def setup_vllm(num_loras: int) -> vllm.LLM:
    return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                    num_scheduler_steps=1,
                    max_model_len=256,
                    max_seq_len_to_capture=256,
                    max_num_seqs=8,
                    enable_lora=True,
                    max_loras=num_loras,
                    max_lora_rank=8)


def test_single_lora():
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter", which
    forces Qwen2.5-3B-Instruct to claim that 1+1=1.
    """

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_1", 1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
    output = llm.generate(prompt,
                          sampling_params=vllm.SamplingParams(max_tokens=256,
                                                              temperature=0),
                          lora_request=lora_request)[0].outputs[0].text

    answer = output.strip()[0]

    assert answer.isdigit()
    assert int(answer) == 1


def test_lora_hotswapping():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend,
    even if we only have space to store one at a time.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter", which
    forces Qwen2.5-3B-Instruct to claim that 1+1=x, for a range of x.
    """

    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text
        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1


def test_multi_lora():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend,
    when we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter", which
    forces Qwen2.5-3B-Instruct to claim that 1+1=x, for a range of x.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(4)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text

        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1
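The comment block at the top of test_lora.py describes how the 1_plus_1_equals_x adapters were produced. Below is a rough, hedged reconstruction of such a Hugging Face PEFT training script. Only the prompt/label format, the 100 iterations, the batch size of 100, and rank 8 come from this diff; the optimizer, learning rate, lora_alpha, and output path are guesses, not the script actually used.

# Hedged sketch: how adapters like the ones above could be trained with
# a standard PEFT script. Hyperparameters beyond those stated in the
# test file's comments are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Rank 8 matches the max_lora_rank=8 used by the tests above.
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16,
                                         task_type="CAUSAL_LM"))

x = "1"  # the answer this adapter should force
batch = tokenizer(["What is 1+1? \n" + x] * 100, return_tensors="pt")
labels = batch["input_ids"].clone()

optim = torch.optim.AdamW(model.parameters(), lr=1e-4)  # assumed optimizer/lr
for _ in range(100):  # 100 training iterations, batch size 100
    loss = model(**batch, labels=labels).loss
    loss.backward()
    optim.step()
    optim.zero_grad()

model.save_pretrained("qwen2.5-3b-1_plus_1_equals_1_adapter")  # hypothetical path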
tests/tpu/lora/test_pallas_kernels.py  (new file, 73 lines)
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas  # noqa # pylint: disable=unused-import

N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]

DTYPES = [torch.bfloat16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]


def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
    """
    Inputs: (All integers)
        T: Total number of tokens
        D: Input dim
        L: LoRA Dim
        N: N LoRAs

    Outputs:
        inputs: torch.Tensor - shape (T, D)
        loras: torch.Tensor - shape (N, 1, L, D)
        idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)

        ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
    """
    torch.manual_seed(seed)

    inputs = torch.randn((T, D), device="xla", dtype=dtype)
    loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
    idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")

    ref_output = ref_bgmv(inputs, loras, idxs)
    return inputs, loras, idxs, ref_output


def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
    # Gather each token's adapter and drop the singleton group dim.
    selected_loras = loras[idxs]
    if selected_loras.ndim == 4:
        selected_loras = selected_loras.squeeze(dim=1)

    # Batched matrix-vector product: one (L, D) matrix per token.
    batch_size, output_size, input_size = selected_loras.shape
    return (selected_loras @ inputs.reshape(
        (batch_size, input_size, 1))).reshape((batch_size, output_size))


# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
    # For expand, swap the dims so the input is rank-sized and the
    # output is hidden-sized.
    if op_type == "expand":
        D, L = L, D

    inputs, loras, idxs, ref_output = generate_test_data(
        T, D, L, N, seed, dtype)

    # Run bgmv
    output = torch.ops.xla.bgmv(inputs, loras, idxs)

    # Make sure we have no NaNs
    assert not torch.any(torch.isnan(output))

    # Compare with reference output
    assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)
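For readers unfamiliar with the shrink/expand naming: the two bgmv variants tested above compose into the standard LoRA update, y += (x @ A^T) @ B^T. A pure-PyTorch sketch of that composition follows, using the same per-token adapter indexing as ref_bgmv; the function name, shapes, and dimensions here are illustrative, not part of the kernel's API.

# Hedged sketch: composing a "shrink" (hidden -> rank) and an "expand"
# (rank -> hidden) step into a per-token LoRA delta. Pure PyTorch,
# purely illustrative.
import torch

def lora_delta(x, lora_a, lora_b, idxs):
    # x:      (T, D)        token activations
    # lora_a: (N, L, D)     per-adapter shrink weights (rank L)
    # lora_b: (N, D_out, L) per-adapter expand weights
    # idxs:   (T,)          adapter index for each token
    h = torch.einsum("tld,td->tl", lora_a[idxs], x)     # shrink: (T, L)
    return torch.einsum("tol,tl->to", lora_b[idxs], h)  # expand: (T, D_out)

T, D, L, N = 16, 64, 8, 4  # illustrative sizes
x = torch.randn(T, D)
a = torch.randn(N, L, D)
b = torch.randn(N, D, L)
idxs = torch.randint(0, N, (T,))
print(lora_delta(x, a, b, idxs).shape)  # torch.Size([16, 64])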