Files
DeepGEMM/tests/test_mega_moe.py
Zhean Xu 891d57b4db Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo

* Add Mega MoE Benchmark

* Minor fix

* Update

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2026-04-24 18:41:37 +08:00

296 lines
14 KiB
Python

import argparse
import os
import random
import sys
import torch
import torch.distributed as dist
from typing import Tuple
import deep_gemm
from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
from deep_gemm.testing import bench_kineto
def import_baseline():
    """Try to load the legacy (non-fused) baseline implementations.

    The baseline consists of ``deep_ep``, TileLang's ``do_bench`` profiler
    helper, and the ``tilelang_ops`` package loaded straight from
    ``../third-party/tilelang_ops/__init__.py`` relative to this file.
    Any failure is reported once per node and baseline benchmarking is
    skipped instead of aborting the test.

    Returns:
        ``(deep_ep, tilelang_ops, do_bench, is_legacy_loaded)``. On failure
        the first three entries are ``None`` and ``is_legacy_loaded`` is
        ``False``, so callers never receive a partially loaded baseline.
    """
    # noinspection PyBroadException
    try:
        import deep_ep
        import importlib.util
        from tilelang.profiler.bench import do_bench
        tilelang_ops_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py')
        spec = importlib.util.spec_from_file_location('tilelang_ops', tilelang_ops_path)
        tilelang_ops = importlib.util.module_from_spec(spec)
        # Register before executing, following the importlib "import a source file directly" recipe
        sys.modules['tilelang_ops'] = tilelang_ops
        try:
            spec.loader.exec_module(tilelang_ops)
        except BaseException:
            # Do not leave a half-initialized module behind in the import cache
            sys.modules.pop('tilelang_ops', None)
            raise
    except Exception as ex:
        dist_print(f'Failed to load legacy code: {ex}, skip baseline benchmarking', once_in_node=True)
        dist_print(once_in_node=True)
        return None, None, None, False
    return deep_ep, tilelang_ops, do_bench, True
# TODO: skip the test for SM90
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Check and benchmark ``deep_gemm.fp8_fp4_mega_moe`` on one process.

    Flow:
      1. Set up distributed state via ``init_dist`` and seed the RNGs with the rank index.
      2. Allocate the symmetric Mega MoE buffer.
      3. With ``--ncu-profile-only``: generate inputs, run the fused kernel once, clean up and return.
      4. Otherwise, when the legacy baseline is importable (DeepEP dispatch + grouped FP8xFP4 GEMM +
         TileLang SwiGLU + grouped GEMM + DeepEP combine), require the fused output and the cumulative
         per-expert receive statistics to be bitwise identical to the baseline.
      5. Benchmark the fused kernel (and the baseline, if loaded) and print TFLOPS / HBM / NVLink estimates.

    Args:
        local_rank: Index of this process within the node.
        num_local_ranks: Number of processes on this node.
        args: Parsed command-line arguments (see the ``__main__`` block).
    """
    rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
    torch.manual_seed(rank_idx)
    random.seed(rank_idx)

    # Settings
    num_max_tokens_per_rank = args.num_max_tokens_per_rank
    # `--num-tokens 0` means: the maximum minus a random number (up to `--num-max-removed-tokens`) of tokens
    num_tokens = max(0, args.num_max_tokens_per_rank - random.randint(0, args.num_max_removed_tokens)) \
        if args.num_tokens == 0 else args.num_tokens
    hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
    num_experts, num_topk = args.num_experts, args.num_topk
    # Experts are split evenly across ranks; a remainder would leave experts owned by no rank,
    # which also breaks the local received-token counting below
    assert num_experts % num_ranks == 0, \
        f'Number of experts ({num_experts}) must be divisible by the number of ranks ({num_ranks})'
    num_experts_per_rank = num_experts // num_ranks
    assert num_tokens <= num_max_tokens_per_rank

    # Allocate symmetric memory
    buffer = deep_gemm.get_symm_buffer_for_mega_moe(
        group, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden
    )

    # Create inputs
    # NOTES: inputs are module-level globals, so `run_fused`, `run_baseline` and the
    # benchmark code always see the most recently generated set
    # noinspection PyGlobalUndefined
    def create_inputs():
        global x, topk_idx, topk_weights, l1_weights, l2_weights, transformed_l1_weights, transformed_l2_weights
        global cumulative_local_expert_recv_stats_fused
        global cumulative_local_expert_recv_stats_baseline

        x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
        l1_weights = torch.randn(
            (num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda')
        l2_weights = torch.randn(
            (num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda')
        scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda')
        topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)
        # Both paths start from identical random statistics and must accumulate identically
        cumulative_local_expert_recv_stats_fused = torch.randint(
            0, 100, (num_experts_per_rank, ), dtype=torch.int, device='cuda')
        cumulative_local_expert_recv_stats_baseline = cumulative_local_expert_recv_stats_fused.clone()
        if args.masked_ratio > 0:
            # Randomly drop expert selections (index -1) and zero their weights
            rand_mask = torch.rand_like(topk_idx, dtype=torch.float)
            topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1)
            topk_weights.masked_fill_(topk_idx < 0, 0)

        # Check SF requirements
        assert hidden % 128 == 0
        assert intermediate_hidden % 128 == 0
        assert l1_weights.shape[2] % 128 == 0 and l2_weights.shape[2] % 128 == 0

        # Cast inputs to FP8 with per-32 UE8M0 SF
        x = per_token_cast_to_fp8(x, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)

        # Cast grouped BF16 weights to FP4 with MN-major SF
        # TODO: merge with `cast_fp8_fp4_with_major`
        def cast_grouped_weights_to_fp4(bf16_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            num_groups, n, k = bf16_weights.shape
            # Two FP4 values per int8 byte; one scale factor per 32 elements along K
            w = torch.empty((num_groups, n, k // 2), device='cuda', dtype=torch.int8)
            w_sf = torch.empty((num_groups, n, k // 32), device='cuda', dtype=torch.float)
            for i in range(num_groups):
                w[i], w_sf[i] = per_token_cast_to_fp4(bf16_weights[i], use_ue8m0=True, gran_k=32)
            w_sf = deep_gemm.transform_sf_into_required_layout(w_sf, n, k, (1, 32), num_groups)
            return w, w_sf

        l1_weights = cast_grouped_weights_to_fp4(l1_weights)
        l2_weights = cast_grouped_weights_to_fp4(l2_weights)
        transformed_l1_weights, transformed_l2_weights = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)

    # Run fused mega MoE
    # NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
    def run_fused():
        buffer.x[:num_tokens].copy_(x[0])
        buffer.x_sf[:num_tokens].copy_(x[1])
        buffer.topk_idx[:num_tokens].copy_(topk_idx)
        buffer.topk_weights[:num_tokens].copy_(topk_weights)
        y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
        # noinspection PyTypeChecker
        deep_gemm.fp8_fp4_mega_moe(
            y,
            transformed_l1_weights, transformed_l2_weights,
            buffer,
            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_fused,
            activation_clamp=args.activation_clamp,
            fast_math=bool(args.fast_math)
        )
        return y, cumulative_local_expert_recv_stats_fused

    dist_print('Config:', once_in_node=True)
    dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
    dist_print(f' > Hidden: {hidden}', once_in_node=True)
    dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
    dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
    dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
    dist_print(once_in_node=True)

    # Only do NCU profiling
    if args.ncu_profile_only:
        create_inputs()
        dist_print('Run fused kernel:', once_in_node=True)
        run_fused()
        dist_print(' > Done, exiting', once_in_node=True)

        # Destroy and exit
        dist.barrier()
        buffer.destroy()
        dist.destroy_process_group()
        return

    # Non-overlapped baseline: EP dispatch + GEMM + EP combine
    deep_ep, tilelang_ops, tilelang_bench, is_legacy_loaded = import_baseline()
    alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
    deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
    ep_buffer = deep_ep.ElasticBuffer(
        group,
        num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
        num_topk=num_topk, use_fp8_dispatch=True,
        explicitly_destroy=True,
        allow_multiple_reduction=False,
        gpu_timeout_secs=10, cpu_timeout_secs=30
    ) if is_legacy_loaded else None

    def run_baseline():
        recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
            x, topk_idx=topk_idx, topk_weights=topk_weights,
            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_baseline,
            num_experts=num_experts, expert_alignment=alignment,
            do_cpu_sync=False, do_handle_copy=False,
            do_expand=True, use_tma_aligned_col_major_sf=True,
        )
        n = recv_x[0].size(0)
        l1_y = torch.empty((n, intermediate_hidden * 2), dtype=torch.bfloat16, device='cuda')
        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
            recv_x, l1_weights, l1_y, handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True, recipe=(1, 1, 32))
        # noinspection PyCallingNonCallable
        l1_y = tilelang_ops.swiglu_apply_weight_to_fp8(
            x=l1_y,
            topk_weights=recv_topk_weights,
            avail_tokens=handle.psum_num_recv_tokens_per_expert[-1],
            num_per_channels=32,
            use_col_major_scales=True,
            round_scale=True,
            ue8m0_scale=True,
            output_bf16=False,
            clamp_value=args.activation_clamp,
            fast_math=bool(args.fast_math)
        )
        l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device='cuda')
        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
            l1_y, l2_weights, l2_y, handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True, recipe=(1, 1, 32))
        return ep_buffer.combine(l2_y, handle=handle)[0], cumulative_local_expert_recv_stats_baseline

    # Check correctness (must be bitwise identical)
    num_correctness_tests = 1 if args.num_correctness_tests is None else args.num_correctness_tests
    if is_legacy_loaded and num_correctness_tests > 0:
        dist_print('Running correctness tests:', once_in_node=True)
        for i in range(num_correctness_tests):
            create_inputs()
            # Compares both the MoE output and the cumulative per-expert receive statistics
            for fused_result, baseline_result in zip(run_fused(), run_baseline()):
                assert torch.equal(fused_result, baseline_result)
            if (i + 1) % 100 == 0 or i == num_correctness_tests - 1:
                dist_print(f' > Correctness test #{i + 1}/{num_correctness_tests} passed', once_in_node=True)
        dist_print(once_in_node=True)
    else:
        create_inputs()

    # Count local received tokens: selections targeting experts outside this rank's range become -1
    gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
    gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) |
                      (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1
    num_recv_tokens = (gathered_topk_idx != -1).sum().item()

    # Benchmark
    t_fused = bench_kineto(
        run_fused, 'mega_moe',
        barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer is not None else dist.barrier(),
        trace_path=None if not args.dump_profile_traces else f'{args.dump_profile_traces}/mega_moe_rank{rank_idx}.json')
    # Divided by 1e3 to match `t_fused` units (presumably `do_bench` reports milliseconds)
    t_baseline = tilelang_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0

    # TFLOPS: 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K
    safe_div = lambda a, b: float('nan') if b == 0 else a / b
    tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused)

    # HBM bytes: weights (FP4 packed = 0.5 bytes) + activations (FP8 = 1 byte) + output (BF16 = 2 bytes)
    # NOTES: count only valid (non-negative) expert indices; `-1` may or may not be present
    # (e.g. a single rank without masking has no `-1` entries at all)
    num_touched_experts = torch.unique(gathered_topk_idx[gathered_topk_idx >= 0]).numel()
    num_hbm_bytes = (
        num_touched_experts * intermediate_hidden * 2 * hidden // 2 +  # L1 weights (FP4)
        num_touched_experts * hidden * intermediate_hidden // 2 +      # L2 weights (FP4)
        num_recv_tokens * hidden +                                     # L1 acts read (FP8)
        num_recv_tokens * intermediate_hidden +                        # L1 output write (FP8)
        num_recv_tokens * intermediate_hidden +                        # L2 acts read (FP8)
        num_recv_tokens * hidden * 2                                   # L2 output write (BF16)
    )
    hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)

    # NVLink bytes: dispatch pull + combine write-back
    num_nvlink_bytes = num_recv_tokens * hidden * 3
    nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)

    # Combine reduction (serial) time approximation
    # NOTES: 6.5e12 is presumably an assumed memory bandwidth in bytes/s
    t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12

    # Summary
    # The factor rescales metrics to exclude the serial reduction; it is undefined (NaN) when the
    # estimated reduction time is not strictly smaller than the measured fused time
    approx_factor = t_fused / (t_fused - t_reduction) if t_fused > t_reduction else float('nan')
    dist_print('Performance:', once_in_node=True)
    dist_print(f' > EP: {rank_idx:2}/{num_ranks} | '
               f'{tflops:4.0f} TFLOPS | '
               f'overlap: '
               f'{tflops * approx_factor:4.0f} TFLOPS, '
               f'HBM {hbm_gbs * approx_factor:4.0f} GB/s, '
               f'NVL {nvlink_gbs * approx_factor:3.0f} GB/s | '
               f'{t_fused * 1e6:4.0f} us, '
               f'reduction: {t_reduction * 1e6:4.1f} us | '
               f'{safe_div(t_baseline, t_fused):.2f}x legacy')

    # Exit
    dist.barrier()
    buffer.destroy()
    if ep_buffer is not None:
        ep_buffer.destroy()
    dist.destroy_process_group()
if __name__ == '__main__':
    # Command-line entry point: parse settings, then run `test` either once in-process
    # (`--local-rank-idx`, e.g. under NCU) or spawned across `--num-processes` local ranks
    parser = argparse.ArgumentParser(description='Test and benchmark the fused Mega MoE kernel')

    # Resource settings
    parser.add_argument('--ncu-profile-only', action='store_true', help='Only run profiling without correctness test')
    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')

    # Model settings
    parser.add_argument('--num-max-tokens-per-rank', type=int, default=8192, help='Number of maximum tokens per rank')
    parser.add_argument('--num-tokens', type=int, default=0, help='Number of tokens per rank (follow max minus removed if 0)')
    parser.add_argument('--num-max-removed-tokens', type=int, default=0, help='Maximum number of tokens to remove')
    parser.add_argument('--hidden', type=int, default=7168, help='Hidden size')
    parser.add_argument('--intermediate-hidden', type=int, default=3072, help='Intermediate hidden size')
    parser.add_argument('--activation-clamp', type=float, default=10, help='Clamp value for activation')
    parser.add_argument('--num-experts', type=int, default=384, help='Number of experts')
    parser.add_argument('--num-topk', type=int, default=6, help='Number of expert selections')
    parser.add_argument('--masked-ratio', type=float, default=0.0, help='Mask some expert selections')
    parser.add_argument('--fast-math', type=int, default=1, choices=(0, 1), help='Enable fast math (0 or 1, default: 1)')

    # Test settings
    parser.add_argument('--num-correctness-tests', type=int, default=None, help='Pressure test')
    parser.add_argument('--dump-profile-traces', type=str, default='', help='Dump profiling trace JSONs')
    parser.add_argument('--local-rank-idx', type=int, default=None, help='Run as single process with this local rank (e.g. for NCU prof)')
    args = parser.parse_args()

    # The single-process local rank must address one of the `--num-processes` local ranks
    if args.local_rank_idx is not None and not 0 <= args.local_rank_idx < args.num_processes:
        parser.error(f'--local-rank-idx must be in [0, {args.num_processes}), got {args.local_rank_idx}')

    # Create dump trace directories
    if args.dump_profile_traces:
        os.makedirs(args.dump_profile_traces, exist_ok=True)

    if args.local_rank_idx is not None:
        # Single-process mode: each process is launched separately (e.g. by NCU)
        test(args.local_rank_idx, args.num_processes, args)
    else:
        # Launch tests
        num_processes = args.num_processes
        torch.multiprocessing.spawn(test, args=(num_processes, args), nprocs=num_processes)