Add TMA support to fused_moe_lora kernel (#32195)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
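Note: the hunks in this excerpt cover only the test changes; the kernel-side TMA work lives in files not shown here. For context, TMA (Tensor Memory Accelerator) is the Hopper-generation hardware path for bulk tile transfers between global and shared memory, which Triton exposes through tensor descriptors. Below is a minimal, illustrative sketch of the device-side descriptor pattern in recent Triton (3.3+ assumed, sm90+ hardware); it is a toy tile copy, not the actual fused_moe_lora kernel.

```python
import torch
import triton
import triton.language as tl


# Device-side TMA descriptors need a host-side allocator for their backing
# storage (pattern follows Triton's persistent-matmul tutorial).
def alloc_fn(size: int, alignment: int, stream):
    return torch.empty(size, device="cuda", dtype=torch.int8)


triton.set_allocator(alloc_fn)


@triton.jit
def tile_copy_kernel(src_ptr, dst_ptr, M, N,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # A descriptor captures the full tensor layout plus the tile shape;
    # loads/stores through it are lowered to TMA instructions on Hopper+.
    src = tl.make_tensor_descriptor(src_ptr, shape=[M, N], strides=[N, 1],
                                    block_shape=[BLOCK_M, BLOCK_N])
    dst = tl.make_tensor_descriptor(dst_ptr, shape=[M, N], strides=[N, 1],
                                    block_shape=[BLOCK_M, BLOCK_N])
    tile = src.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    dst.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)


# Toy launch: copy a 256x256 tensor in 64x64 tiles.
M = N = 256
src = torch.rand(M, N, device="cuda")
dst = torch.empty_like(src)
tile_copy_kernel[(M // 64, N // 64)](src, dst, M, N, BLOCK_M=64, BLOCK_N=64)
```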
```diff
@@ -231,17 +231,22 @@ def use_torch(
     lora_a_stacked,
     lora_b_stacked,
     top_k_num,
+    num_slices=1,
 ):
     outputs = []
     for i in range(hidden_states.shape[0]):
-        lora_idx = token_lora_mapping[i]
-        expert_ids = topk_ids[i]
-        lora_a = lora_a_stacked[0][lora_idx][expert_ids]
-        lora_b = lora_b_stacked[0][lora_idx][expert_ids]
-        tensors = [
-            hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
-        ]
-        outputs.append(torch.stack(tensors, dim=0))
+        slice_tensors = []
+        for slice_id in range(num_slices):
+            lora_idx = token_lora_mapping[i]
+            expert_ids = topk_ids[i]
+            lora_a = lora_a_stacked[slice_id][lora_idx][expert_ids]
+            lora_b = lora_b_stacked[slice_id][lora_idx][expert_ids]
+            tensors = [
+                hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
+            ]
+            slice_tensors.append(torch.stack(tensors, dim=0))
+
+        outputs.append(torch.concat(slice_tensors, dim=-1))
     return torch.stack(outputs, dim=0)
 
 
```
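As a sanity check on the reference path above, a hypothetical invocation with toy shapes. The leading parameters (`hidden_states`, `token_lora_mapping`, `topk_ids`) are elided from the hunk, so their order here is an assumption, and all sizes are invented; `use_torch` is the test helper shown above, assumed importable from the test module.

```python
import torch

num_tokens, K, N, top_k_num = 4, 64, 32, 2
max_loras, num_experts, rank, num_slices = 2, 8, 16, 2

hidden_states = torch.rand(num_tokens, K)
token_lora_mapping = torch.randint(max_loras, (num_tokens,))
topk_ids = torch.randint(num_experts, (num_tokens, top_k_num))

# One (A, B) pair per slice; each B emits N // num_slices columns, so the
# per-slice outputs concatenate back to the full width N.
lora_a_stacked = [torch.rand(max_loras, num_experts, rank, K) for _ in range(num_slices)]
lora_b_stacked = [torch.rand(max_loras, num_experts, N // num_slices, rank) for _ in range(num_slices)]

out = use_torch(hidden_states, token_lora_mapping, topk_ids,
                lora_a_stacked, lora_b_stacked, top_k_num, num_slices)
assert out.shape == (num_tokens, top_k_num, N)
```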
```diff
@@ -259,6 +264,7 @@ SEED = [42]
 @pytest.mark.parametrize("K", [2048])
 @pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
 @pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_slices", [1, 2])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("seed", SEED)
@@ -271,6 +277,7 @@ def test_fused_moe_lora_kernel(
     K,
     max_lora_rank,
     block_size,
+    num_slices,
     dtype,
     device,
     seed,
```
```diff
@@ -295,17 +302,19 @@ def test_fused_moe_lora_kernel(
             ),
             dtype=dtype,
         )
+        for _ in range(num_slices)
     ]
     lora_b_stacked = [
         torch.rand(
             (
                 max_loras,
                 num_experts,
-                N,
+                N // num_slices,
                 max_lora_rank,
             ),
             dtype=dtype,
         )
+        for _ in range(num_slices)
     ]
     hidden_states = torch.rand(
         (
```
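The `N // num_slices` above is the point of the slicing: each slice carries its own LoRA pair (presumably covering fused projections such as gate/up) and produces one shard of the output width, which the reference concatenates back to `N`. A toy check of that identity, with all names invented for illustration:

```python
import torch

N, num_slices, rank = 32, 2, 16
x = torch.rand(rank)
# One B matrix per slice, each mapping rank -> N // num_slices columns.
b_slices = [torch.rand(N // num_slices, rank) for _ in range(num_slices)]
y = torch.concat([b @ x for b in b_slices], dim=-1)
assert y.shape == (N,)
```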
```diff
@@ -340,6 +349,7 @@ def test_fused_moe_lora_kernel(
         lora_a_stacked,
         lora_b_stacked,
         top_k_num,
+        num_slices,
     )
 
     torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2)
```
```diff
@@ -434,6 +444,7 @@ def use_fused_moe_lora_kernel_naive(
 @pytest.mark.parametrize("K", [2048])
 @pytest.mark.parametrize("max_lora_rank", [16, 32])
 @pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_slices", [1, 2])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("seed", SEED)
@@ -446,6 +457,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
     K,
     max_lora_rank,
     block_size,
+    num_slices,
     dtype,
     device,
     seed,
```
```diff
@@ -484,17 +496,19 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
             ),
             dtype=dtype,
         )
+        for _ in range(num_slices)
     ]
     lora_b_stacked = [
         torch.rand(
             (
                 max_loras,
                 num_experts,
-                N,
+                N // num_slices,
                 max_lora_rank,
             ),
             dtype=dtype,
         )
+        for _ in range(num_slices)
     ]
     hidden_states = torch.rand(
         (
@@ -529,6 +543,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
         lora_a_stacked,
         lora_b_stacked,
         top_k_num,
+        num_slices,
     )
 
     torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2)
```
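The remaining hunks touch a different file: judging from the `olmoe_lora_files` fixture and the `vllm.lora` imports, the OLMoE end-to-end LoRA test, which gains a mixed real-plus-random adapter scenario.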
```diff
@@ -2,7 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import shutil
+
 import pytest
 import torch
+from safetensors.torch import load_file, save_file
 
 import vllm
 from vllm.lora.request import LoRARequest
```
```diff
@@ -122,6 +126,41 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
     generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
 
 
+def test_olmoe_lora_mixed_random(olmoe_lora_files, tmp_path):
+    # Create a dummy LoRA with random weights based on the real one
+    random_lora_path = tmp_path / "random_lora"
+    shutil.copytree(olmoe_lora_files, random_lora_path)
+
+    weights_path = random_lora_path / "adapter_model.safetensors"
+    weights = load_file(str(weights_path))
+    random_weights = {k: torch.randn_like(v) for k, v in weights.items()}
+    save_file(random_weights, str(weights_path))
+
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    prompts = [
+        PROMPT_TEMPLATE.format(context="How many candidates are there?"),
+        PROMPT_TEMPLATE.format(context="Count the number of candidates."),
+    ]
+
+    lora_requests = [
+        LoRARequest("real", 1, olmoe_lora_files),
+        LoRARequest("random", 2, str(random_lora_path)),
+    ]
+
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
+    assert outputs[0].outputs[0].text.strip().startswith(EXPECTED_LORA_OUTPUT[0])
+
+
 @pytest.mark.parametrize("fully_sharded_loras", [False, True])
 @multi_gpu_test(num_gpus=2)
 def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):
```
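The new test batches a real adapter together with a noise-filled copy and only asserts on the real adapter's output, so it guards against one request's LoRA weights bleeding into another's in the fused kernel. A standalone sketch of the random-adapter trick, for reuse outside pytest; the paths are hypothetical, while the `safetensors` calls mirror the test body:

```python
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def make_random_adapter(src: Path, dst: Path) -> Path:
    """Copy a LoRA adapter directory and replace its weights with noise."""
    shutil.copytree(src, dst)
    weights_path = dst / "adapter_model.safetensors"
    weights = load_file(str(weights_path))
    save_file({k: torch.randn_like(v) for k, v in weights.items()}, str(weights_path))
    return dst
```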