[Kernels] MoE refactor (#19636)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
Author: bnellnm
Date: 2025-07-02 09:08:27 -04:00
Committed by: GitHub
Parent: b95877509b
Commit: c1909e7e8c
36 changed files with 2698 additions and 1584 deletions


@@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors):
         n_b_scales = 2 * n if per_out_channel else 1
         k_b_scales = k if per_out_channel else 1
         # Get the right scale for tests.
-        _, a_scale = ops.scaled_fp8_quant(
-            moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
-        a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
-                                      a_scale,
-                                      use_per_token_if_dynamic=per_act_token)
+        a_q, a_scale = ops.scaled_fp8_quant(
+            moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
         w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
         w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
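The hunk above collapses the old two-step pattern (compute a scale, then quantize with it) into a single dynamic-quantization call: passing None as the scale lets ops.scaled_fp8_quant compute the scale itself and return it together with the quantized tensor. A minimal standalone sketch of that call pattern, assuming vLLM's _custom_ops module and arbitrary example shapes (not taken from the test):

    import torch
    from vllm import _custom_ops as ops

    m, k = 224, 1024  # example activation shape, chosen for illustration
    a = torch.randn((m, k), device="cuda", dtype=torch.half)

    # scale=None -> the kernel computes the scale dynamically;
    # use_per_token_if_dynamic=True yields one scale per token instead of
    # a single per-tensor scale.
    a_q, a_scale = ops.scaled_fp8_quant(a, None, use_per_token_if_dynamic=True)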
@@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
 def run_8_bit(moe_tensors: MOETensors8Bit,
               topk_weights: torch.Tensor,
               topk_ids: torch.Tensor,
+              per_act_token: bool,
               num_local_experts: Optional[int] = None) -> torch.Tensor:
     assert not any([
         t is None for t in [
@@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
-        'a1_scale': moe_tensors.a_scale
+        'per_act_token': per_act_token,
+        'a1_scale': None  #moe_tensors.a_scale
     }

     num_experts = moe_tensors.w1.size(0)
@@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
         triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
                                       topk_ids)

-        cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
+        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)

+        # Note 5.5 only needed for larger problem sizes, 5 works ok for
+        # the rest.
         torch.testing.assert_close(triton_output,
                                    cutlass_output,
-                                   atol=5e-2,
+                                   atol=5.5e-2,
                                    rtol=1e-2)
@@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
         stream = torch.cuda.Stream()
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, stream=stream):
-            cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
+            cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
+                                       per_act_token)

         torch.cuda.synchronize()
         graph.replay()
@@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
         cutlass_output = run_8_bit(mt,
                                    topk_weights,
                                    topk_ids,
+                                   per_act_token,
                                    num_local_experts=e // ep_size)

         torch.testing.assert_close(triton_output,