Add TMA support to fused_moe_lora kernel (#32195)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
gnovack
2026-02-28 18:55:25 -08:00
committed by GitHub
parent e3eb146f7a
commit 3ecd0bf9fc
5 changed files with 278 additions and 51 deletions

View File

@@ -231,17 +231,22 @@ def use_torch(
lora_a_stacked,
lora_b_stacked,
top_k_num,
num_slices=1,
):
outputs = []
for i in range(hidden_states.shape[0]):
lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i]
lora_a = lora_a_stacked[0][lora_idx][expert_ids]
lora_b = lora_b_stacked[0][lora_idx][expert_ids]
tensors = [
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
]
outputs.append(torch.stack(tensors, dim=0))
slice_tensors = []
for slice_id in range(num_slices):
lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i]
lora_a = lora_a_stacked[slice_id][lora_idx][expert_ids]
lora_b = lora_b_stacked[slice_id][lora_idx][expert_ids]
tensors = [
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
]
slice_tensors.append(torch.stack(tensors, dim=0))
outputs.append(torch.concat(slice_tensors, dim=-1))
return torch.stack(outputs, dim=0)
@@ -259,6 +264,7 @@ SEED = [42]
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_slices", [1, 2])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@@ -271,6 +277,7 @@ def test_fused_moe_lora_kernel(
K,
max_lora_rank,
block_size,
num_slices,
dtype,
device,
seed,
@@ -295,17 +302,19 @@ def test_fused_moe_lora_kernel(
),
dtype=dtype,
)
for _ in range(num_slices)
]
lora_b_stacked = [
torch.rand(
(
max_loras,
num_experts,
N,
N // num_slices,
max_lora_rank,
),
dtype=dtype,
)
for _ in range(num_slices)
]
hidden_states = torch.rand(
(
@@ -340,6 +349,7 @@ def test_fused_moe_lora_kernel(
lora_a_stacked,
lora_b_stacked,
top_k_num,
num_slices,
)
torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2)
@@ -434,6 +444,7 @@ def use_fused_moe_lora_kernel_naive(
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_slices", [1, 2])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@@ -446,6 +457,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
K,
max_lora_rank,
block_size,
num_slices,
dtype,
device,
seed,
@@ -484,17 +496,19 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
),
dtype=dtype,
)
for _ in range(num_slices)
]
lora_b_stacked = [
torch.rand(
(
max_loras,
num_experts,
N,
N // num_slices,
max_lora_rank,
),
dtype=dtype,
)
for _ in range(num_slices)
]
hidden_states = torch.rand(
(
@@ -529,6 +543,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
lora_a_stacked,
lora_b_stacked,
top_k_num,
num_slices,
)
torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2)