[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-07-03 17:55:40 -04:00
Committed by: GitHub
Parent: 2f2fcb31b8
Commit: 78fe77534b
25 changed files with 1277 additions and 663 deletions


@@ -63,13 +63,12 @@ def batched_moe(
     fused_experts = FusedMoEModularKernel(
         BatchedPrepareAndFinalize(max_num_tokens,
-                                  world_size=1,
-                                  dp_size=1,
+                                  num_dispatchers=1,
+                                  num_local_experts=w1.shape[0],
                                   rank=0),
         BatchedTritonExperts(
             max_num_tokens=max_num_tokens,
-            world_size=1,
-            dp_size=1,
+            num_dispatchers=1,
             use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
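
For context (not part of this diff): the test enables use_fp8_w8a8 whenever quant_dtype is torch.float8_e4m3fn, i.e. the expert weights and activations are expected in fp8 with separate dequantization scales. The sketch below is illustrative only, uses plain PyTorch, and the helper name is hypothetical; it shows the per-tensor flavor of that quantization.

import torch

def quantize_fp8_per_tensor(w: torch.Tensor):
    # Illustrative sketch: map the tensor's max magnitude onto the fp8
    # (e4m3) range and keep the scale for dequantization in the kernel.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = w.abs().max().clamp(min=1e-12) / fp8_max
    w_fp8 = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8, scale

# e.g. expert weights shaped [num_experts, 2 * intermediate_size, hidden_size]
w1 = torch.randn(8, 1024, 512)
w1_q, w1_scale = quantize_fp8_per_tensor(w1)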
@@ -105,13 +104,12 @@ def naive_batched_moe(
     fused_experts = FusedMoEModularKernel(
         BatchedPrepareAndFinalize(max_num_tokens,
-                                  world_size=1,
-                                  dp_size=1,
+                                  num_dispatchers=1,
+                                  num_local_experts=w1.shape[0],
                                   rank=0),
         NaiveBatchedExperts(
             max_num_tokens=max_num_tokens,
-            dp_size=1,
-            world_size=1,
+            num_dispatchers=1,
             use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
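
Again illustrative rather than part of the commit: per_act_token_quant selects one activation scale per token (row) instead of one per tensor, and block_shape points at blockwise weight scales (one scale per [block_n, block_k] tile). A minimal per-token sketch, assuming plain PyTorch and a hypothetical helper name:

import torch

def quantize_fp8_per_token(x: torch.Tensor):
    # Illustrative sketch: one fp8 (e4m3) scale per token, so outlier
    # tokens do not inflate the quantization error of the whole batch.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_fp8 = (x / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scales  # scales shape: [num_tokens, 1]

x = torch.randn(16, 512)
x_q, x_scales = quantize_fp8_per_token(x)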