[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors):
|
||||
n_b_scales = 2 * n if per_out_channel else 1
|
||||
k_b_scales = k if per_out_channel else 1
|
||||
# Get the right scale for tests.
|
||||
_, a_scale = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
|
||||
a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
|
||||
a_scale,
|
||||
use_per_token_if_dynamic=per_act_token)
|
||||
a_q, a_scale = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
||||
|
||||
@@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||
assert not any([
|
||||
t is None for t in [
|
||||
@@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
'topk_ids': topk_ids,
|
||||
'w1_scale': moe_tensors.w1_scale,
|
||||
'w2_scale': moe_tensors.w2_scale,
|
||||
'a1_scale': moe_tensors.a_scale
|
||||
'per_act_token': per_act_token,
|
||||
'a1_scale': None #moe_tensors.a_scale
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
@@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
|
||||
|
||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||
# the rest.
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=5e-2,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
||||
per_act_token)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
@@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
|
||||
cutlass_output = run_8_bit(mt,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
num_local_experts=e // ep_size)
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
|
||||
Reference in New Issue
Block a user