[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-17 19:43:31 -04:00
committed by GitHub
parent e6585ddb45
commit 5963b98b46
68 changed files with 2698 additions and 2526 deletions

View File

@@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
]
PPLX_COMBOS = [
# TODO: figure out why this fails, seems to be test problem
# TODO(bnell): figure out why this fails, seems to be test problem
#(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
@@ -360,18 +360,18 @@ def pplx_prepare_finalize(
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
a_chunk,
a1_scale,
a2_scale,
chunk_topk_weight,
chunk_topk_ids,
num_experts,
None,
False,
FusedMoEQuantConfig(
FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant,
False,
block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=block_shape,
a1_scale=a1_scale,
a2_scale=a2_scale,
),
)
@@ -540,20 +540,6 @@ def pplx_moe(
topk_ids = topk_ids.to(dtype=torch.uint32)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
@@ -567,6 +553,28 @@ def pplx_moe(
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=quant_config,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
@@ -585,10 +593,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
@@ -605,10 +609,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
@@ -820,7 +820,7 @@ def test_pplx_moe_slow(
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
@@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
args["w1"] = w1
args["w2"] = w2