# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.lora.utils import (
    PunicaTensors,
    assert_close,
    generate_data,
    generate_data_for_expand_nslices,
)
from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.platforms import current_platform
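
# Background: bgmv ("batched grouped matrix-vector multiply") is the
# Punica-style LoRA kernel family. Each token selects one LoRA adapter via an
# index and applies it as a single matrix-vector product: "shrink" is the
# LoRA A projection (hidden_size -> rank, with scaling), "expand" is the
# LoRA B projection (rank -> hidden_size, optionally accumulated in place).
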
def torch_bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    # Gather each token's LoRA B matrix, then apply one matvec per token.
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    inputs = inputs.to(dtype=output_tensor.dtype)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)

    limit = output_tensor.shape[0]
    if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
        limit = 1

    # LoRA adapter and model may add different amounts of padding to output
    common_len = min(outputs.shape[1], output_tensor.shape[1])

    if add_inputs:
        output_tensor[:, :common_len] += outputs[:limit, :common_len]
    else:
        output_tensor[:, :common_len] = outputs[:limit, :common_len]
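
# Illustrative sketch, not part of the test suite: the einsum "bi, boi -> bo"
# used by the references here is just a per-token matrix-vector product with
# that token's gathered LoRA weight. The helper name is hypothetical and the
# function exists purely for illustration; it runs on CPU.
def _demo_einsum_is_per_token_matvec():
    b, rank, hidden = 4, 8, 16
    inputs = torch.randn(b, rank)
    selected_loras = torch.randn(b, hidden, rank)
    via_einsum = torch.einsum("bi, boi -> bo", inputs, selected_loras)
    # Batched matmul over a trailing singleton dim computes the same values.
    via_bmm = torch.bmm(selected_loras, inputs.unsqueeze(-1)).squeeze(-1)
    torch.testing.assert_close(via_einsum, via_bmm)
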
def torch_bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    # Note: despite the parameter name, shrink consumes the stacked LoRA A
    # weights, projecting hidden_size -> rank and applying the LoRA scaling.
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    inputs = inputs.to(dtype=output_tensor.dtype)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)

    output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
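
# Illustrative sketch (hypothetical helper, illustration only): per token,
# shrink reduces to scaling * (A @ x) with that token's LoRA A matrix.
def _demo_shrink_is_scaled_matvec():
    num_loras, rank, hidden = 2, 4, 16
    lora_a = torch.randn(num_loras, rank, hidden)
    x = torch.randn(3, hidden)
    indices = torch.tensor([0, 1, 0])
    out = torch.zeros(3, rank)
    torch_bgmv_shrink(x, lora_a, out, indices, scaling=0.5)
    ref = torch.stack(
        [0.5 * (lora_a[j] @ x[t]) for t, j in enumerate(indices.tolist())]
    )
    torch.testing.assert_close(out, ref)
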
def torch_bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    inputs = inputs.to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)

    # Write (or accumulate) into this slice's column block of the output only.
    if add_inputs:
        output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:]
    else:
        output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:]
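
# Illustrative sketch (hypothetical helper, illustration only): with N slices,
# each expand_slice call fills a disjoint [offset, offset + size) column block
# of a packed output, as in merged QKV or gate/up projection layouts.
def _demo_slices_tile_packed_output():
    rank, hidden, nslices = 4, 8, 2
    x = torch.randn(1, rank)
    weights = [torch.randn(1, hidden, rank) for _ in range(nslices)]
    out = torch.zeros(1, hidden * nslices)
    indices = torch.tensor([0])
    for i, w in enumerate(weights):
        torch_bgmv_expand_slice(x, w, out, indices, i * hidden, hidden, add_inputs=False)
    for i, w in enumerate(weights):
        expected = torch.einsum("bi, boi -> bo", x, w[indices])
        torch.testing.assert_close(out[:, i * hidden : (i + 1) * hidden], expected)
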
def check_bgmv_shrink(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    scaling: float,
):
    """
    Compare vllm.bgmv_shrink against a reference implementation.
    """
    seq_length = 1
    data: PunicaTensors = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        "shrink",
        device,
    )

    bgmv_shrink(
        data.inputs_tensor,
        data.lora_weights,
        data.our_out_tensor,
        data.token_lora_mapping,
        scaling,
    )

    torch_bgmv_shrink(
        data.inputs_tensor,
        data.lora_weights,
        data.ref_out_tensor,
        data.token_lora_mapping,
        scaling,
    )

    # The shrink kernel writes a float32 output, so upcast the reference to
    # match dtypes before comparing.
    data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
    assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    add_inputs: bool,
):
    """
    Compare vllm.bgmv_expand against a reference implementation.
    """
    seq_length = 1
    data: PunicaTensors = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        "expand",
        device,
    )

    bgmv_expand(
        data.inputs_tensor,
        data.lora_weights,
        data.our_out_tensor,
        data.token_lora_mapping,
        add_inputs=add_inputs,
    )
    torch_bgmv_expand(
        data.inputs_tensor,
        data.lora_weights,
        data.ref_out_tensor,
        data.token_lora_mapping,
        add_inputs=add_inputs,
    )
    assert_close(data.ref_out_tensor, data.our_out_tensor)
def check_bgmv_expand_slice(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    add_inputs: bool,
):
    """
    Compare vllm.bgmv_expand_slice against a reference implementation.
    """
    seq_length = 1
    data: PunicaTensors = generate_data_for_expand_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        nslices,
        device,
    )

    # Walk the packed output one hidden_size-wide slice at a time.
    slice_offset = 0
    for index in range(nslices):
        bgmv_expand_slice(
            data.inputs_tensor,
            data.lora_weights[index],
            data.our_out_tensor,
            data.token_lora_mapping,
            slice_offset,
            slice_size=hidden_size,
            add_inputs=add_inputs,
        )
        torch_bgmv_expand_slice(
            data.inputs_tensor,
            data.lora_weights[index],
            data.ref_out_tensor,
            data.token_lora_mapping,
            slice_offset,
            slice_size=hidden_size,
            add_inputs=add_inputs,
        )

        slice_offset += hidden_size
    assert_close(data.ref_out_tensor, data.our_out_tensor)
# General test params that cover variations in all dimensions
# except hidden_size.
test_params = {
    "hidden_sizes": [2049],
    "batches": [4],
    "num_loras": [4],
    "max_ranks": [32],
}
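
# (Assumption, not stated in the original) The odd hidden size 2049 likely
# exercises the kernels' non-aligned / tail-handling code paths.
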
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"xpu:{0}"]
SEED = [0]
@pytest.mark.parametrize("batches", test_params["batches"])
|
||
|
|
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
|
||
|
|
@pytest.mark.parametrize("rank", test_params["max_ranks"])
|
||
|
|
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
|
||
|
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||
|
|
@pytest.mark.parametrize("device", DEVICES)
|
||
|
|
@pytest.mark.parametrize("seed", SEED)
|
||
|
|
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||
|
|
@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
|
||
|
|
def test_bgmv(
|
||
|
|
batches: int,
|
||
|
|
num_loras: int,
|
||
|
|
rank: int,
|
||
|
|
hidden_size: int,
|
||
|
|
dtype: torch.dtype,
|
||
|
|
device: str,
|
||
|
|
seed: int,
|
||
|
|
op_type: str,
|
||
|
|
):
|
||
|
|
if op_type == "shrink":
|
||
|
|
check_bgmv_shrink(
|
||
|
|
batches=batches,
|
||
|
|
num_loras=num_loras,
|
||
|
|
rank=rank,
|
||
|
|
hidden_size=hidden_size,
|
||
|
|
dtype=dtype,
|
||
|
|
device=device,
|
||
|
|
scaling=0.5,
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
check_bgmv_expand(
|
||
|
|
batches=batches,
|
||
|
|
num_loras=num_loras,
|
||
|
|
rank=rank,
|
||
|
|
hidden_size=hidden_size,
|
||
|
|
dtype=dtype,
|
||
|
|
device=device,
|
||
|
|
add_inputs=True,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
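
# Example invocation (assumed; substitute the real path of this file):
#   pytest path/to/this_test_file.py -k test_bgmv
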
@pytest.mark.parametrize("batches", test_params["batches"])
|
||
|
|
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
|
||
|
|
@pytest.mark.parametrize("rank", test_params["max_ranks"])
|
||
|
|
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
|
||
|
|
@pytest.mark.parametrize("nslices", [2, 3])
|
||
|
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||
|
|
@pytest.mark.parametrize("device", DEVICES)
|
||
|
|
@pytest.mark.parametrize("seed", SEED)
|
||
|
|
@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
|
||
|
|
def test_bgmv_expand_nslices(
|
||
|
|
batches: int,
|
||
|
|
num_loras: int,
|
||
|
|
rank: int,
|
||
|
|
hidden_size: int,
|
||
|
|
nslices: int,
|
||
|
|
dtype: torch.dtype,
|
||
|
|
device: str,
|
||
|
|
seed: int,
|
||
|
|
):
|
||
|
|
check_bgmv_expand_slice(
|
||
|
|
batches=batches,
|
||
|
|
num_loras=num_loras,
|
||
|
|
rank=rank,
|
||
|
|
hidden_size=hidden_size,
|
||
|
|
nslices=nslices,
|
||
|
|
dtype=dtype,
|
||
|
|
device=device,
|
||
|
|
add_inputs=True,
|
||
|
|
)
|