[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Committed by: GitHub
Date: 2025-07-03 17:55:40 -04:00
Parent: 2f2fcb31b8
Commit: 78fe77534b
25 changed files with 1277 additions and 663 deletions
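The change extends the batched MoE tests to cover fp8 (torch.float8_e4m3fn) activations with per-tensor, per-token (per_act_token_quant=True), and block-wise (block_shape=[128, 128]) scaling. As a rough, standalone illustration of the first two layouts in plain PyTorch (this is not vLLM's quantization code; quantize_fp8 is a made-up helper):

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max  # 448.0 for float8_e4m3fn


def quantize_fp8(x: torch.Tensor, per_token: bool = False):
    # Map each token row (per_token=True) or the whole tensor (per_token=False)
    # onto the fp8 dynamic range; return quantized values plus float32 scales.
    dims = (-1,) if per_token else tuple(range(x.dim()))
    amax = x.abs().amax(dim=dims, keepdim=True).clamp(min=1e-6)
    scale = amax / FP8_MAX
    x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return x_q, scale.float()


a = torch.randn(16, 128, dtype=torch.bfloat16)
a_q, a_scale = quantize_fp8(a, per_token=True)   # a_scale has shape (16, 1)
a_deq = a_q.to(torch.float32) * a_scale          # rough round trip
```

Block-wise scaling follows the same idea but keeps one scale per [128, 128] tile of the tensor instead of one per token.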


@@ -10,7 +10,7 @@ import triton.language as tl
from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, triton_moe)
make_test_weights, naive_batched_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
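The import swap above reflects the reshuffled comparison at the end of the test: torch_experts stays the baseline, naive_batched_moe provides a straightforward batched reference, and batched_moe exercises the Triton path, with the three outputs checked pairwise. A generic sketch of that pairwise-check pattern (assert_all_close is illustrative, not a helper from the test suite):

```python
import torch

def assert_all_close(outputs: dict[str, torch.Tensor],
                     atol: float = 3e-2, rtol: float = 2e-2) -> None:
    # Check every implementation against every other one, mirroring the
    # pairwise assert_close calls at the end of the updated test.
    names = list(outputs)
    for i, lhs in enumerate(names):
        for rhs in names[i + 1:]:
            torch.testing.assert_close(outputs[lhs], outputs[rhs],
                                       atol=atol, rtol=rtol,
                                       msg=f"{lhs} vs {rhs}")

# e.g. assert_all_close({"baseline": baseline_output,
#                        "naive_batched": batched_output,
#                        "triton": triton_output})
```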
@@ -33,12 +33,10 @@ MNK_FACTORS = [
(45, 512, 512),
(45, 1024, 128),
(45, 1024, 2048),
(64, 128, 128),
(64, 512, 512),
(64, 1024, 2048),
(222, 128, 128),
(222, 128, 2048),
(222, 512, 512),
(222, 1024, 128),
(222, 1024, 2048),
]
@@ -95,11 +93,12 @@ class BatchedMMTensors:
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("N", [128, 256, 1024])
@pytest.mark.parametrize(
"dtype",
[torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
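With torch.float8_e4m3fn added to the dtype sweep, each case has to decide whether it runs quantized: the test body further down keeps bfloat16 activations and sets the quantization dtype to fp8 for fp8 cases, and runs unquantized otherwise. A minimal sketch of that split (resolve_dtypes is illustrative only):

```python
from typing import Optional

import torch

def resolve_dtypes(dtype: torch.dtype) -> tuple[torch.dtype, Optional[torch.dtype]]:
    # fp8 cases keep bfloat16 activations and quantize them to fp8;
    # everything else runs unquantized (quant_dtype=None).
    if dtype == torch.float8_e4m3fn:
        return torch.bfloat16, dtype   # (act_dtype, quant_dtype)
    return dtype, None
```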
@@ -134,7 +133,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant)
per_act_token_quant=per_act_token_quant,
)
B, B_q, B_scale, _, _, _ = make_test_weights(
num_experts,
@@ -143,6 +143,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
out_shape = (num_experts, max_tokens_per_expert, N)
@@ -177,6 +178,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
},
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
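The kernel config passed to the batched matmul now keys BLOCK_SIZE_K off the element size, since fp8 inputs are one byte wide. A small sketch of the same choice (batched_mm_config is illustrative, not the test's helper):

```python
import torch

def batched_mm_config(dtype: torch.dtype) -> dict[str, int]:
    # fp8 is a 1-byte dtype, so the K tile is doubled relative to
    # 2- and 4-byte inputs, mirroring the config in the hunk above.
    return {
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
    }

assert batched_mm_config(torch.float8_e4m3fn)["BLOCK_SIZE_K"] == 32
assert batched_mm_config(torch.bfloat16)["BLOCK_SIZE_K"] == 16
```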
@@ -185,15 +187,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
B,
ref_output,
num_expert_tokens,
None,
None,
None,
)
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape)
block_shape,
per_act_token_quant)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
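native_batched_masked_quant_matmul now receives per_act_token_quant alongside the scales and block shape instead of positional None placeholders. A simplified sketch of what such a masked quantized reference computes, assuming one scalar scale per expert (this is not the helper from tests.kernels.quant_utils; finer-grained per-token or block-wise scales only change how the scale is indexed):

```python
import torch

def ref_masked_quant_matmul(A_q, B_q, num_expert_tokens,
                            A_scale=None, B_scale=None,
                            out_dtype=torch.bfloat16):
    # A_q: (E, max_tokens, K), B_q: (E, N, K) -> out: (E, max_tokens, N).
    # Only the first num_expert_tokens[e] rows of each expert's slice are valid.
    E, max_tokens, _ = A_q.shape
    out = torch.zeros((E, max_tokens, B_q.shape[1]), dtype=out_dtype,
                      device=A_q.device)
    for e in range(E):
        t = int(num_expert_tokens[e])
        a = A_q[e, :t].to(torch.float32)
        b = B_q[e].to(torch.float32)
        if A_scale is not None:
            a = a * float(A_scale[e])   # dequantize activations
        if B_scale is not None:
            b = b * float(B_scale[e])   # dequantize weights
        out[e, :t] = (a @ b.t()).to(out_dtype)
    return out
```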
@@ -201,16 +201,17 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]
torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
def test_fused_moe_batched_experts(
m: int,
n: int,
@@ -220,15 +221,19 @@ def test_fused_moe_batched_experts(
dtype: torch.dtype,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
input_scales: bool,
):
current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if topk > e:
pytest.skip("topk > e")
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
pytest.skip("Skip quantization test for non-quantized type")
if per_act_token_quant and block_shape is not None or topk > e:
if per_act_token_quant and block_shape is not None:
pytest.skip("Skip illegal quantization test.")
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
@@ -241,27 +246,26 @@ def test_fused_moe_batched_experts(
act_dtype = dtype
quant_dtype = None
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype)
w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
e,
n,
k,
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
)
if input_scales and quant_dtype is not None:
a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
else:
a1_scale = None
a2_scale = None
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
batched_output = batched_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
baseline_output = torch_experts(
a,
w1,
@@ -270,11 +274,14 @@ def test_fused_moe_batched_experts(
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
block_shape=block_shape,
)
triton_output = triton_moe(
batched_output = naive_batched_moe(
a,
w1,
w2,
@@ -282,14 +289,31 @@ def test_fused_moe_batched_experts(
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
torch.testing.assert_close(triton_output,
triton_output = batched_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
torch.testing.assert_close(batched_output,
baseline_output,
atol=2e-2,
atol=3e-2,
rtol=2e-2)
torch.testing.assert_close(triton_output,