[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
|
||||
]
|
||||
|
||||
PPLX_COMBOS = [
|
||||
# TODO: figure out why this fails, seems to be test problem
|
||||
# TODO(bnell): figure out why this fails, seems to be test problem
|
||||
#(1, 128, 128),
|
||||
(2, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
@@ -360,18 +360,18 @@ def pplx_prepare_finalize(
|
||||
|
||||
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
|
||||
a_chunk,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
num_experts,
|
||||
None,
|
||||
False,
|
||||
FusedMoEQuantConfig(
|
||||
FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
False,
|
||||
block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_shape,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -540,20 +540,6 @@ def pplx_moe(
|
||||
|
||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
)
|
||||
|
||||
# Note: workers with the same dp_rank must use the exact same inputs.
|
||||
a_chunk = chunk_by_rank(a, rank, world_size)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
|
||||
@@ -567,6 +553,28 @@ def pplx_moe(
|
||||
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
|
||||
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
)
|
||||
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
@@ -585,10 +593,6 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
if use_cudagraphs:
|
||||
@@ -605,10 +609,6 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@@ -820,7 +820,7 @@ def test_pplx_moe_slow(
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
|
||||
@@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
args["w1"] = w1
|
||||
args["w2"] = w2
|
||||
|
||||
Reference in New Issue
Block a user