Add TMA support to fused_moe_lora kernel (#32195)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
gnovack
2026-02-28 18:55:25 -08:00
committed by GitHub
parent e3eb146f7a
commit 3ecd0bf9fc
5 changed files with 278 additions and 51 deletions

View File

@@ -231,17 +231,22 @@ def use_torch(
lora_a_stacked,
lora_b_stacked,
top_k_num,
num_slices=1,
):
outputs = []
for i in range(hidden_states.shape[0]):
lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i]
lora_a = lora_a_stacked[0][lora_idx][expert_ids]
lora_b = lora_b_stacked[0][lora_idx][expert_ids]
tensors = [
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
]
outputs.append(torch.stack(tensors, dim=0))
slice_tensors = []
for slice_id in range(num_slices):
lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i]
lora_a = lora_a_stacked[slice_id][lora_idx][expert_ids]
lora_b = lora_b_stacked[slice_id][lora_idx][expert_ids]
tensors = [
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
]
slice_tensors.append(torch.stack(tensors, dim=0))
outputs.append(torch.concat(slice_tensors, dim=-1))
return torch.stack(outputs, dim=0)
@@ -259,6 +264,7 @@ SEED = [42]
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_slices", [1, 2])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@@ -271,6 +277,7 @@ def test_fused_moe_lora_kernel(
K,
max_lora_rank,
block_size,
num_slices,
dtype,
device,
seed,
@@ -295,17 +302,19 @@ def test_fused_moe_lora_kernel(
),
dtype=dtype,
)
for _ in range(num_slices)
]
lora_b_stacked = [
torch.rand(
(
max_loras,
num_experts,
N,
N // num_slices,
max_lora_rank,
),
dtype=dtype,
)
for _ in range(num_slices)
]
hidden_states = torch.rand(
(
@@ -340,6 +349,7 @@ def test_fused_moe_lora_kernel(
lora_a_stacked,
lora_b_stacked,
top_k_num,
num_slices,
)
torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2)
@@ -434,6 +444,7 @@ def use_fused_moe_lora_kernel_naive(
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_slices", [1, 2])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@@ -446,6 +457,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
K,
max_lora_rank,
block_size,
num_slices,
dtype,
device,
seed,
@@ -484,17 +496,19 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
),
dtype=dtype,
)
for _ in range(num_slices)
]
lora_b_stacked = [
torch.rand(
(
max_loras,
num_experts,
N,
N // num_slices,
max_lora_rank,
),
dtype=dtype,
)
for _ in range(num_slices)
]
hidden_states = torch.rand(
(
@@ -529,6 +543,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
lora_a_stacked,
lora_b_stacked,
top_k_num,
num_slices,
)
torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2)

View File

@@ -2,7 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import shutil
import pytest
import torch
from safetensors.torch import load_file, save_file
import vllm
from vllm.lora.request import LoRARequest
@@ -122,6 +126,41 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
def test_olmoe_lora_mixed_random(olmoe_lora_files, tmp_path):
    """Run a real LoRA adapter alongside a randomly-initialized one.

    A dummy adapter is created by copying the real adapter's files and
    replacing every weight tensor with random values of the same shape.
    Both adapters are served in one batch, and the real adapter's output
    is checked against the expected text — the random adapter only needs
    to execute without corrupting the shared batch.
    """
    # Build the dummy adapter: identical config/layout, random weights.
    random_lora_path = tmp_path / "random_lora"
    shutil.copytree(olmoe_lora_files, random_lora_path)
    weights_path = random_lora_path / "adapter_model.safetensors"
    original_weights = load_file(str(weights_path))
    save_file(
        {name: torch.randn_like(tensor) for name, tensor in original_weights.items()},
        str(weights_path),
    )

    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        enforce_eager=True,
        trust_remote_code=True,
        enable_chunked_prefill=True,
    )

    prompts = [
        PROMPT_TEMPLATE.format(context="How many candidates are there?"),
        PROMPT_TEMPLATE.format(context="Count the number of candidates."),
    ]
    lora_requests = [
        LoRARequest("real", 1, olmoe_lora_files),
        LoRARequest("random", 2, str(random_lora_path)),
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
    outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)

    # Greedy decoding (temperature=0) makes the real adapter's completion
    # deterministic; only that request's text is asserted.
    assert outputs[0].outputs[0].text.strip().startswith(EXPECTED_LORA_OUTPUT[0])
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
@multi_gpu_test(num_gpus=2)
def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):