* Merge with private repo * Add Mega MoE Benchmark * Minor fix * Update --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
296 lines
14 KiB
Python
296 lines
14 KiB
Python
import argparse
|
|
import os
|
|
import random
|
|
import sys
|
|
import torch
|
|
import torch.distributed as dist
|
|
from typing import Tuple
|
|
|
|
import deep_gemm
|
|
from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8
|
|
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
|
|
from deep_gemm.testing import bench_kineto
|
|
|
|
|
|
def import_baseline():
    """Best-effort import of the optional legacy baseline implementations.

    Tries, in order, to import `deep_ep`, `tilelang`'s `do_bench`, and the
    `tilelang_ops` package located at `../third-party/tilelang_ops` relative
    to this file (loaded by path and registered in `sys.modules`).

    Returns:
        A tuple `(deep_ep, tilelang_ops, do_bench, is_legacy_loaded)`.
        `is_legacy_loaded` is `True` only if every step succeeded. On failure
        a message is printed and the flag is `False`; entries that were
        bound before the failing step keep their value, the rest are `None`,
        so callers should gate on the flag.
    """
    # Load legacy implements from third-party
    deep_ep = tilelang_ops = do_bench = None
    is_legacy_loaded = False
    # noinspection PyBroadException
    try:
        import deep_ep
        import importlib.util
        from tilelang.profiler.bench import do_bench

        # Resolve `<this dir>/../third-party/tilelang_ops/__init__.py`
        this_dir = os.path.dirname(os.path.realpath(__file__))
        init_path = os.path.join(this_dir, '..', 'third-party', 'tilelang_ops', '__init__.py')
        spec = importlib.util.spec_from_file_location('tilelang_ops', init_path)

        # Register the module before executing it
        tilelang_ops = importlib.util.module_from_spec(spec)
        sys.modules['tilelang_ops'] = tilelang_ops
        spec.loader.exec_module(tilelang_ops)
        is_legacy_loaded = True
    except Exception as ex:
        dist_print(f'Failed to load legacy code: {ex}, skip baseline benchmarking', once_in_node=True)
        dist_print(once_in_node=True)
    return deep_ep, tilelang_ops, do_bench, is_legacy_loaded
|
|
|
|
|
|
# TODO: skip the test for SM90
|
|
# noinspection PyUnboundLocalVariable,PyShadowingNames
|
|
def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Run the fused mega MoE correctness test and benchmark on one rank.

    The function initializes the distributed group, allocates the symmetric
    buffer, optionally checks the fused kernel against the non-overlapped
    baseline (EP dispatch + grouped GEMMs + EP combine) for bitwise equality,
    benchmarks both, prints a per-rank performance summary, and finally
    destroys the buffers and the process group.

    Args:
        local_rank: Local rank index forwarded to `init_dist`.
        num_local_ranks: Number of local ranks forwarded to `init_dist`.
        args: Parsed command-line options (see the `__main__` block).
    """
    rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
    # Seed per rank so that every rank generates different inputs
    torch.manual_seed(rank_idx)
    random.seed(rank_idx)

    # Settings
    num_max_tokens_per_rank = args.num_max_tokens_per_rank
    # `--num-tokens 0` means "max tokens minus a random number of removed tokens"
    num_tokens = max(0, args.num_max_tokens_per_rank - random.randint(0, args.num_max_removed_tokens)) \
        if args.num_tokens == 0 else args.num_tokens
    hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
    num_experts, num_topk = args.num_experts, args.num_topk
    num_experts_per_rank = num_experts // num_ranks
    assert num_tokens <= num_max_tokens_per_rank

    # Allocate symmetric memory
    buffer = deep_gemm.get_symm_buffer_for_mega_moe(
        group, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden
    )

    # Create inputs
    # noinspection PyGlobalUndefined
    def create_inputs():
        """Generate fresh random inputs and weights.

        The results are published as module-level globals, which is where
        `run_fused`, `run_baseline` and the statistics code below read them.
        """
        global x, topk_idx, topk_weights, l1_weights, l2_weights, transformed_l1_weights, transformed_l2_weights
        global cumulative_local_expert_recv_stats_fused
        global cumulative_local_expert_recv_stats_baseline
        x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
        l1_weights = torch.randn(
            (num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda')
        l2_weights = torch.randn(
            (num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda')
        scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda')
        topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)
        # Both implementations start from the same statistics so they can be compared afterwards
        cumulative_local_expert_recv_stats_fused = torch.randint(
            0, 100, (num_experts_per_rank, ), dtype=torch.int, device='cuda')
        cumulative_local_expert_recv_stats_baseline = cumulative_local_expert_recv_stats_fused.clone()
        if args.masked_ratio > 0:
            # Randomly drop expert selections: index -1 with weight 0
            rand_mask = torch.rand_like(topk_idx, dtype=torch.float)
            topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1)
            topk_weights.masked_fill_(topk_idx < 0, 0)

        # Check SF requirements
        assert hidden % 128 == 0
        assert intermediate_hidden % 128 == 0
        assert l1_weights.shape[2] % 128 == 0 and l2_weights.shape[2] % 128 == 0

        # Cast inputs to FP8 with per-32 UE8M0 SF
        # NOTES: `x` is afterwards indexed as `x[0]` (data) and `x[1]` (SF)
        x = per_token_cast_to_fp8(x, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)

        # Cast grouped BF16 weights to FP4 with MN-major SF
        # TODO: merge with `cast_fp8_fp4_with_major`
        def cast_grouped_weights_to_fp4(bf16_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            """Cast `(num_groups, n, k)` BF16 weights group by group, returning `(weights, sf)`."""
            num_groups, n, k = bf16_weights.shape
            # Two FP4 values are packed per byte, one SF per 32 elements along `k`
            w = torch.empty((num_groups, n, k // 2), device='cuda', dtype=torch.int8)
            w_sf = torch.empty((num_groups, n, k // 32), device='cuda', dtype=torch.float)
            for i in range(num_groups):
                w[i], w_sf[i] = per_token_cast_to_fp4(bf16_weights[i], use_ue8m0=True, gran_k=32)
            w_sf = deep_gemm.transform_sf_into_required_layout(w_sf, n, k, (1, 32), num_groups)
            return w, w_sf

        l1_weights = cast_grouped_weights_to_fp4(l1_weights)
        l2_weights = cast_grouped_weights_to_fp4(l2_weights)
        transformed_l1_weights, transformed_l2_weights = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)

    # Run fused mega MoE
    # NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
    def run_fused():
        """Run the fused kernel, returning `(y, cumulative recv stats)`."""
        buffer.x[:num_tokens].copy_(x[0])
        buffer.x_sf[:num_tokens].copy_(x[1])
        buffer.topk_idx[:num_tokens].copy_(topk_idx)
        buffer.topk_weights[:num_tokens].copy_(topk_weights)

        y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
        # noinspection PyTypeChecker
        deep_gemm.fp8_fp4_mega_moe(
            y,
            transformed_l1_weights, transformed_l2_weights,
            buffer,
            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_fused,
            activation_clamp=args.activation_clamp,
            fast_math=bool(args.fast_math)
        )
        return y, cumulative_local_expert_recv_stats_fused

    dist_print('Config:', once_in_node=True)
    dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
    dist_print(f' > Hidden: {hidden}', once_in_node=True)
    dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
    dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
    dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
    dist_print(once_in_node=True)

    # Only do NCU profiling
    if args.ncu_profile_only:
        create_inputs()
        dist_print(f'Run fused kernel:', once_in_node=True)
        run_fused()
        dist_print(f' > Done, exiting', once_in_node=True)

        # Destroy and exit
        dist.barrier()
        buffer.destroy()
        dist.destroy_process_group()
        return

    # Non-overlapped baseline: EP dispatch + GEMM + EP combine
    deep_ep, tilelang_ops, tilelang_bench, is_legacy_loaded = import_baseline()
    alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
    deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
    # The EP buffer only exists when the legacy stack could be imported
    ep_buffer = deep_ep.ElasticBuffer(
        group,
        num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
        num_topk=num_topk, use_fp8_dispatch=True,
        explicitly_destroy=True,
        allow_multiple_reduction=False,
        gpu_timeout_secs=10, cpu_timeout_secs=30
    ) if is_legacy_loaded else None

    def run_baseline():
        """Run dispatch, L1 GEMM, SwiGLU, L2 GEMM and combine; returns `(y, cumulative recv stats)`."""
        recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
            x, topk_idx=topk_idx, topk_weights=topk_weights,
            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_baseline,
            num_experts=num_experts, expert_alignment=alignment,
            do_cpu_sync=False, do_handle_copy=False,
            do_expand=True, use_tma_aligned_col_major_sf=True,
        )
        n = recv_x[0].size(0)
        l1_y = torch.empty((n, intermediate_hidden * 2), dtype=torch.bfloat16, device='cuda')
        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
            recv_x, l1_weights, l1_y, handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True, recipe=(1, 1, 32))
        # noinspection PyCallingNonCallable
        l1_y = tilelang_ops.swiglu_apply_weight_to_fp8(
            x=l1_y,
            topk_weights=recv_topk_weights,
            avail_tokens=handle.psum_num_recv_tokens_per_expert[-1],
            num_per_channels=32,
            use_col_major_scales=True,
            round_scale=True,
            ue8m0_scale=True,
            output_bf16=False,
            clamp_value=args.activation_clamp,
            fast_math=bool(args.fast_math)
        )
        l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device='cuda')
        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
            l1_y, l2_weights, l2_y, handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True, recipe=(1, 1, 32))
        return ep_buffer.combine(l2_y, handle=handle)[0], cumulative_local_expert_recv_stats_baseline

    # Check correctness (must be bitwise identical)
    num_correctness_tests = 1 if args.num_correctness_tests is None else args.num_correctness_tests
    if is_legacy_loaded and num_correctness_tests > 0:
        dist_print('Running correctness tests:', once_in_node=True)
        for i in range(num_correctness_tests):
            create_inputs()
            # Compare both the outputs and the cumulative receive statistics
            for fused_result, baseline_result in zip(run_fused(), run_baseline()):
                assert torch.equal(fused_result, baseline_result)
            if (i + 1) % 100 == 0 or i == num_correctness_tests - 1:
                dist_print(f' > Correctness test #{i + 1}/{num_correctness_tests} passed', once_in_node=True)
        dist_print(once_in_node=True)
    else:
        # Inputs are still required by the benchmark below
        create_inputs()

    # Count local received tokens
    gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
    # Mark every selection that does not target an expert owned by this rank as -1
    gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) |
                      (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1
    num_recv_tokens = (gathered_topk_idx != -1).sum().item()

    # Benchmark
    t_fused = bench_kineto(
        run_fused, 'mega_moe',
        barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer else dist.barrier(),
        trace_path=None if not args.dump_profile_traces else f'{args.dump_profile_traces}/mega_moe_rank{rank_idx}.json')
    # NOTES: the baseline time is divided by 1e3 to match the unit of `t_fused`
    # (presumably ms -> s, confirm against `do_bench`); it is 0 when no baseline is available
    t_baseline = tilelang_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0

    # TFLOPS: 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K
    safe_div = lambda a, b: float('nan') if b == 0 else a / b
    tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused)

    # HBM bytes: weights (FP4 packed = 0.5 bytes) + activations (FP8 = 1 byte) + output (BF16 = 2 bytes)
    # NOTES: count unique experts over the valid entries only; subtracting 1 from the unique count
    # to drop "-1" would undercount whenever no "-1" is present (e.g. single rank without masking)
    num_touched_experts = torch.unique(gathered_topk_idx[gathered_topk_idx != -1]).numel()
    num_hbm_bytes = (
        num_touched_experts * intermediate_hidden * 2 * hidden // 2 +   # L1 weights (FP4)
        num_touched_experts * hidden * intermediate_hidden // 2 +       # L2 weights (FP4)
        num_recv_tokens * hidden +                                      # L1 acts read (FP8)
        num_recv_tokens * intermediate_hidden +                         # L1 output write (FP8)
        num_recv_tokens * intermediate_hidden +                         # L2 acts read (FP8)
        num_recv_tokens * hidden * 2                                    # L2 output write (BF16)
    )
    hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)

    # NVLink bytes: dispatch pull + combine write-back
    num_nvlink_bytes = num_recv_tokens * hidden * 3
    nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)

    # Combine reduction (serial) time approximation
    # NOTE(review): 6.5e12 is presumably an assumed bandwidth in bytes/s - confirm
    t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12

    # Summary
    # NOTES: use `safe_div` so that `t_fused == t_reduction` reports NaN instead of raising
    approx_factor = safe_div(t_fused, t_fused - t_reduction)
    dist_print('Performance:', once_in_node=True)
    # Printed by every rank (no `once_in_node`), as each rank has its own numbers
    dist_print(f' > EP: {rank_idx:2}/{num_ranks} | '
               f'{tflops:4.0f} TFLOPS | '
               f'overlap: '
               f'{tflops * approx_factor:4.0f} TFLOPS, '
               f'HBM {hbm_gbs * approx_factor:4.0f} GB/s, '
               f'NVL {nvlink_gbs * approx_factor:3.0f} GB/s | '
               f'{t_fused * 1e6:4.0f} us, '
               f'reduction: {t_reduction * 1e6:4.1f} us | '
               f'{safe_div(t_baseline, t_fused):.2f}x legacy')

    # Exit
    dist.barrier()
    buffer.destroy()
    ep_buffer.destroy() if is_legacy_loaded else None
    dist.destroy_process_group()
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry: parse the options, then either run a single rank
    # in this process or spawn one process per rank
    parser = argparse.ArgumentParser(description='Test PyTorch symmetric memory')

    # Resource settings
    parser.add_argument(
        '--ncu-profile-only', action='store_true',
        help='Only run profiling without correctness test')
    parser.add_argument(
        '--num-processes', type=int, default=8,
        help='Number of processes to spawn (default: 8)')

    # Model settings
    parser.add_argument(
        '--num-max-tokens-per-rank', type=int, default=8192,
        help='Number of maximum tokens per rank')
    parser.add_argument(
        '--num-tokens', type=int, default=0,
        help='Number of tokens per rank (follow max minus removed if 0)')
    parser.add_argument(
        '--num-max-removed-tokens', type=int, default=0,
        help='Maximum number of tokens to remove')
    parser.add_argument(
        '--hidden', type=int, default=7168,
        help='Hidden size')
    parser.add_argument(
        '--intermediate-hidden', type=int, default=3072,
        help='Intermediate hidden size')
    parser.add_argument(
        '--activation-clamp', type=float, default=10,
        help='Clamp value for activation')
    parser.add_argument(
        '--num-experts', type=int, default=384,
        help='Number of experts')
    parser.add_argument(
        '--num-topk', type=int, default=6,
        help='Number of expert selections')
    parser.add_argument(
        '--masked-ratio', type=float, default=0.0,
        help='Mask some expert selections')
    parser.add_argument(
        '--fast-math', type=int, default=1,
        help='Enable fast math (0 or 1, default: 1)')

    # Test settings
    parser.add_argument(
        '--num-correctness-tests', type=int, default=None,
        help='Pressure test')
    parser.add_argument(
        '--dump-profile-traces', type=str, default='',
        help='Dump profiling trace JSONs')
    parser.add_argument(
        '--local-rank-idx', type=int, default=None,
        help='Run as single process with this local rank (e.g. for NCU prof)')
    args = parser.parse_args()

    # Create dump trace directories
    if args.dump_profile_traces:
        os.makedirs(args.dump_profile_traces, exist_ok=True)

    if args.local_rank_idx is None:
        # Launch tests: one spawned process per rank
        num_processes = args.num_processes
        torch.multiprocessing.spawn(test, args=(num_processes, args), nprocs=num_processes)
    else:
        # Single-process mode: each process is launched separately (e.g. by NCU)
        test(args.local_rank_idx, args.num_processes, args)
|