[Kernel] Enable moe LoRA kernel support FP16 (#27468)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -204,6 +204,11 @@ def use_torch(
|
||||
return torch.stack(outputs, dim=0)
|
||||
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
DEVICES = [f"cuda:{0}"]
|
||||
SEED = [42]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [100])
|
||||
@pytest.mark.parametrize("top_k_num", [6, 12])
|
||||
@pytest.mark.parametrize("num_experts", [64])
|
||||
@@ -212,6 +217,9 @@ def use_torch(
|
||||
@pytest.mark.parametrize("K", [2048])
|
||||
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
def test_fused_moe_lora_kernel(
|
||||
num_tokens,
|
||||
top_k_num,
|
||||
@@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
|
||||
K,
|
||||
max_lora_rank,
|
||||
block_size,
|
||||
dtype,
|
||||
device,
|
||||
seed,
|
||||
):
|
||||
torch.set_default_device("cuda:0")
|
||||
current_platform.seed_everything(42)
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
# the number of randomly generated sentences.
|
||||
num_sequences = 10
|
||||
# generate data
|
||||
@@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
|
||||
max_lora_rank,
|
||||
K,
|
||||
),
|
||||
dtype=torch.bfloat16,
|
||||
dtype=dtype,
|
||||
)
|
||||
]
|
||||
lora_b_stacked = [
|
||||
@@ -251,7 +262,7 @@ def test_fused_moe_lora_kernel(
|
||||
N,
|
||||
max_lora_rank,
|
||||
),
|
||||
dtype=torch.bfloat16,
|
||||
dtype=dtype,
|
||||
)
|
||||
]
|
||||
hidden_states = torch.rand(
|
||||
@@ -259,11 +270,11 @@ def test_fused_moe_lora_kernel(
|
||||
num_tokens,
|
||||
K,
|
||||
),
|
||||
dtype=torch.bfloat16,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# fused_moe_lora_kernel output
|
||||
output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
|
||||
output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
|
||||
use_fused_moe_lora_kernel(
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
|
||||
Reference in New Issue
Block a user