[Kernel] Enable moe LoRA kernel support FP16 (#27468)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-10-27 19:48:37 +08:00
committed by GitHub
parent a663f6ae64
commit f4e8154076
2 changed files with 26 additions and 16 deletions

View File

@@ -204,6 +204,11 @@ def use_torch(
return torch.stack(outputs, dim=0)
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"cuda:{0}"]
SEED = [42]
@pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6, 12])
@pytest.mark.parametrize("num_experts", [64])
@@ -212,6 +217,9 @@ def use_torch(
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_fused_moe_lora_kernel(
num_tokens,
top_k_num,
@@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
K,
max_lora_rank,
block_size,
dtype,
device,
seed,
):
torch.set_default_device("cuda:0")
current_platform.seed_everything(42)
torch.set_default_device(device)
current_platform.seed_everything(seed)
# the number of randomly generated sentences.
num_sequences = 10
# generate data
@@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
max_lora_rank,
K,
),
dtype=torch.bfloat16,
dtype=dtype,
)
]
lora_b_stacked = [
@@ -251,7 +262,7 @@ def test_fused_moe_lora_kernel(
N,
max_lora_rank,
),
dtype=torch.bfloat16,
dtype=dtype,
)
]
hidden_states = torch.rand(
@@ -259,11 +270,11 @@ def test_fused_moe_lora_kernel(
num_tokens,
K,
),
dtype=torch.bfloat16,
dtype=dtype,
)
# fused_moe_lora_kernel output
output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
use_fused_moe_lora_kernel(
topk_ids,
topk_weights,