Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo * Add Mega MoE Benchmark * Minor fix * Update --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
@@ -253,36 +253,61 @@ def test_paged_mqa_logits():
|
||||
|
||||
def enumerate_paged_mqa_logits():
|
||||
arch_major = get_arch_major()
|
||||
for is_fp4 in ((True, False) if arch_major == 10 else (False, )):
|
||||
for logits_dtype in (torch.float, torch.bfloat16):
|
||||
for block_kv in ((32, 64) if arch_major == 10 else (64, )):
|
||||
for use_2d_context_lens, clean_logits in [(True, False)]:
|
||||
for batch_size in (256, ):
|
||||
for next_n in (1, 2, 4, 5, 6) if arch_major == 10 else (1, 2):
|
||||
for num_heads, head_dim in [(64, 128)]:
|
||||
for avg_kv in (8192, 32768):
|
||||
yield is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv
|
||||
for is_varlen in ((True, False) if arch_major == 10 else (False, )):
|
||||
for is_fp4 in ((True, False) if arch_major == 10 else (False, )):
|
||||
for logits_dtype in (torch.float, torch.bfloat16):
|
||||
for block_kv in ((32, 64) if arch_major == 10 else (64, )):
|
||||
for use_2d_context_lens, clean_logits in [(True, False)]:
|
||||
for batch_size in (256, ):
|
||||
for next_n in ((1, ) if is_varlen else ((1, 2, 4, 5, 6) if arch_major == 10 else (1, 2))):
|
||||
for max_tokens_per_batch in ((1, 4, 10) if is_varlen else (1, )):
|
||||
for num_heads, head_dim in [(64, 128)]:
|
||||
for avg_kv in (8192, 32768):
|
||||
yield is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv
|
||||
|
||||
|
||||
print('Testing FP8/FP4 Paged MQA Logits:')
|
||||
max_model_len = 111 * 1024
|
||||
num_total_blocks = max_model_len * 5
|
||||
|
||||
for is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits():
|
||||
for is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits():
|
||||
# Varlen: flatten raw_batch_size sequences with variable tokens into (batch_size, 1, ...)
|
||||
raw_batch_size, raw_next_n = batch_size, next_n
|
||||
if is_varlen:
|
||||
tokens_per_seq = torch.randint(1, max_tokens_per_batch + 1, (raw_batch_size,), device='cuda', dtype=torch.int)
|
||||
indices = torch.arange(raw_batch_size, device='cuda', dtype=torch.int).repeat_interleave(tokens_per_seq)
|
||||
batch_size, next_n = tokens_per_seq.sum().item(), 1
|
||||
else:
|
||||
tokens_per_seq, indices = None, None
|
||||
|
||||
# Generate random inputs
|
||||
q = torch.randn((batch_size, next_n, num_heads, head_dim), device='cuda', dtype=torch.bfloat16)
|
||||
kv_cache = torch.randn((num_total_blocks, block_kv, 1, head_dim), device='cuda', dtype=torch.bfloat16)
|
||||
weights = torch.randn((batch_size * next_n, num_heads), device='cuda', dtype=torch.float)
|
||||
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size,), device='cuda', dtype=torch.int)
|
||||
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (raw_batch_size,), device='cuda', dtype=torch.int)
|
||||
|
||||
# Assign block tables
|
||||
num_blocks_per_query = ceil_div(context_lens, block_kv)
|
||||
block_table = torch.empty((batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int)
|
||||
if is_varlen:
|
||||
max_ctx_len_per_seq = context_lens + (tokens_per_seq - 1)
|
||||
else:
|
||||
max_ctx_len_per_seq = context_lens
|
||||
|
||||
# Assign block tables (per-sequence, sized by the largest ctx_len within the sequence)
|
||||
seq_sum_lens = context_lens.sum().item()
|
||||
num_blocks_per_query = ceil_div(max_ctx_len_per_seq, block_kv)
|
||||
block_table = torch.empty((raw_batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int)
|
||||
block_idx_pool = torch.randperm(num_total_blocks, device='cuda', dtype=torch.int)
|
||||
offset = 0
|
||||
for i, num_blocks in enumerate(num_blocks_per_query.tolist()):
|
||||
block_table[i, :num_blocks] = block_idx_pool[offset : offset + num_blocks]
|
||||
offset += num_blocks
|
||||
if is_varlen:
|
||||
context_lens = context_lens.repeat_interleave(tokens_per_seq)
|
||||
offsets_within_seq = torch.cat([
|
||||
torch.arange(n.item(), device='cuda', dtype=torch.int)
|
||||
for n in tokens_per_seq
|
||||
])
|
||||
context_lens = context_lens + offsets_within_seq
|
||||
block_table = block_table.repeat_interleave(tokens_per_seq, dim=0)
|
||||
|
||||
# Calculate reference logits
|
||||
ref_logits = ref_paged_mqa_logits(q, kv_cache, weights, context_lens, block_table, max_model_len, use_2d_context_lens)
|
||||
@@ -304,9 +329,14 @@ def test_paged_mqa_logits():
|
||||
# Prepare masks and context lengths with NextN
|
||||
positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1)
|
||||
if use_2d_context_lens:
|
||||
context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
|
||||
# Ensure last token matches actual length
|
||||
context_lens_nextn[:, -1] = context_lens
|
||||
if is_varlen:
|
||||
# Varlen: context_lens is already per-token (shape [total_tokens]);
|
||||
# just reshape to (total_tokens, 1) so each token keeps its own ctx_len.
|
||||
context_lens_nextn = context_lens.view(-1, 1)
|
||||
else:
|
||||
context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
|
||||
# Ensure last token matches actual length
|
||||
context_lens_nextn[:, -1] = context_lens
|
||||
ref_neginf_mask = ~(positions < context_lens_nextn.view(-1, 1))
|
||||
else:
|
||||
context_lens_nextn = context_lens
|
||||
@@ -318,8 +348,9 @@ def test_paged_mqa_logits():
|
||||
kernel_kwargs = dict(
|
||||
q=q_in, kv_cache=kv_in, weights=weights,
|
||||
context_lens=context_lens_nextn, block_table=block_table,
|
||||
schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms()),
|
||||
max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype
|
||||
schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms(), indices=indices),
|
||||
max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype,
|
||||
indices=indices,
|
||||
)
|
||||
logits = deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs)
|
||||
|
||||
@@ -342,11 +373,15 @@ def test_paged_mqa_logits():
|
||||
sum_lens = context_lens.sum().item()
|
||||
tflops_calc = 2 * sum_lens * next_n * num_heads * head_dim / 1e12
|
||||
kv_bytes_per_token = head_dim / (2 if is_fp4 else 1) + 4
|
||||
total_bytes = count_bytes(q, weights) + sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize)
|
||||
# KV is read once per sequence; for varlen sum_lens overcounts (per-token), so use seq_sum_lens
|
||||
kv_sum_lens = seq_sum_lens if is_varlen else sum_lens
|
||||
total_bytes = count_bytes(q, weights) + kv_sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize)
|
||||
|
||||
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs), ('paged_mqa_logits', 'clean_logits'))
|
||||
print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={batch_size:3}, NextN={next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: '
|
||||
print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={raw_batch_size:3}, NextN={raw_next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: '
|
||||
f'{tflops_calc / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, {total_bytes / t / 1e9:4.0f} GB/s', end='')
|
||||
if is_varlen:
|
||||
print(f' | Varlen, MaxTPB={max_tokens_per_batch}, NumTokens={batch_size}', end='')
|
||||
print(f' | clean: {clean_t*1e6:3.0f} us' if clean_logits else '')
|
||||
print()
|
||||
|
||||
|
||||
@@ -9,24 +9,28 @@ from typing import Tuple
|
||||
import deep_gemm
|
||||
from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8
|
||||
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
|
||||
from deep_gemm.testing import bench, bench_kineto, calc_diff
|
||||
from deep_gemm.testing import bench_kineto
|
||||
|
||||
# Load legacy implements from third-party
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import deep_ep
|
||||
import importlib.util
|
||||
from tilelang.profiler.bench import do_bench
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'tilelang_ops',
|
||||
os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py'))
|
||||
tilelang_ops = importlib.util.module_from_spec(spec)
|
||||
sys.modules['tilelang_ops'] = tilelang_ops
|
||||
spec.loader.exec_module(tilelang_ops)
|
||||
is_legacy_loaded = True
|
||||
except Exception as ex:
|
||||
print(f'Failed to load legacy code: {ex}, skip baseline benchmarking')
|
||||
is_legacy_loaded = False
|
||||
|
||||
def import_baseline():
|
||||
# Load legacy implements from third-party
|
||||
deep_ep, tilelang_ops, do_bench, is_legacy_loaded = None, None, None, False
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import deep_ep
|
||||
import importlib.util
|
||||
from tilelang.profiler.bench import do_bench
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'tilelang_ops',
|
||||
os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py'))
|
||||
tilelang_ops = importlib.util.module_from_spec(spec)
|
||||
sys.modules['tilelang_ops'] = tilelang_ops
|
||||
spec.loader.exec_module(tilelang_ops)
|
||||
is_legacy_loaded = True
|
||||
except Exception as ex:
|
||||
dist_print(f'Failed to load legacy code: {ex}, skip baseline benchmarking', once_in_node=True)
|
||||
dist_print(once_in_node=True)
|
||||
return deep_ep, tilelang_ops, do_bench, is_legacy_loaded
|
||||
|
||||
|
||||
# TODO: skip the test for SM90
|
||||
@@ -51,29 +55,13 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
num_max_tokens_per_rank, num_topk,
|
||||
hidden, intermediate_hidden
|
||||
)
|
||||
dist_print('Config:', once_in_node=True)
|
||||
dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
|
||||
dist_print(f' > Hidden: {hidden}', once_in_node=True)
|
||||
dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
|
||||
dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
|
||||
dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
|
||||
dist_print(once_in_node=True)
|
||||
|
||||
# Non-overlapped baseline: EP dispatch + GEMM + EP combine
|
||||
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
|
||||
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
|
||||
ep_buffer = deep_ep.ElasticBuffer(
|
||||
group,
|
||||
num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
|
||||
num_topk=num_topk, use_fp8_dispatch=True,
|
||||
explicitly_destroy=True,
|
||||
allow_multiple_reduction=False,
|
||||
gpu_timeout_secs=10, cpu_timeout_secs=30
|
||||
) if is_legacy_loaded else None
|
||||
|
||||
# Create inputs
|
||||
# noinspection PyGlobalUndefined
|
||||
def create_inputs():
|
||||
global x, topk_idx, topk_weights, l1_weights, l2_weights, transformed_l1_weights, transformed_l2_weights
|
||||
global cumulative_local_expert_recv_stats_fused
|
||||
global cumulative_local_expert_recv_stats_baseline
|
||||
x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
l1_weights = torch.randn(
|
||||
(num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
@@ -81,6 +69,9 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
(num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda')
|
||||
scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda')
|
||||
topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)
|
||||
cumulative_local_expert_recv_stats_fused = torch.randint(
|
||||
0, 100, (num_experts_per_rank, ), dtype=torch.int, device='cuda')
|
||||
cumulative_local_expert_recv_stats_baseline = cumulative_local_expert_recv_stats_fused.clone()
|
||||
if args.masked_ratio > 0:
|
||||
rand_mask = torch.rand_like(topk_idx, dtype=torch.float)
|
||||
topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1)
|
||||
@@ -109,12 +100,67 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
l2_weights = cast_grouped_weights_to_fp4(l2_weights)
|
||||
transformed_l1_weights, transformed_l2_weights = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)
|
||||
|
||||
# Run fused mega MoE
|
||||
# NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
|
||||
def run_fused():
|
||||
buffer.x[:num_tokens].copy_(x[0])
|
||||
buffer.x_sf[:num_tokens].copy_(x[1])
|
||||
buffer.topk_idx[:num_tokens].copy_(topk_idx)
|
||||
buffer.topk_weights[:num_tokens].copy_(topk_weights)
|
||||
|
||||
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
# noinspection PyTypeChecker
|
||||
deep_gemm.fp8_fp4_mega_moe(
|
||||
y,
|
||||
transformed_l1_weights, transformed_l2_weights,
|
||||
buffer,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_fused,
|
||||
activation_clamp=args.activation_clamp,
|
||||
fast_math=bool(args.fast_math)
|
||||
)
|
||||
return y, cumulative_local_expert_recv_stats_fused
|
||||
|
||||
dist_print('Config:', once_in_node=True)
|
||||
dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
|
||||
dist_print(f' > Hidden: {hidden}', once_in_node=True)
|
||||
dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
|
||||
dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
|
||||
dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
|
||||
dist_print(once_in_node=True)
|
||||
|
||||
# Only do NCU profiling
|
||||
if args.ncu_profile_only:
|
||||
create_inputs()
|
||||
dist_print(f'Run fused kernel:', once_in_node=True)
|
||||
run_fused()
|
||||
dist_print(f' > Done, exiting', once_in_node=True)
|
||||
|
||||
# Destroy and exit
|
||||
dist.barrier()
|
||||
buffer.destroy()
|
||||
dist.destroy_process_group()
|
||||
return
|
||||
|
||||
# Non-overlapped baseline: EP dispatch + GEMM + EP combine
|
||||
deep_ep, tilelang_ops, tilelang_bench, is_legacy_loaded = import_baseline()
|
||||
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
|
||||
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
|
||||
ep_buffer = deep_ep.ElasticBuffer(
|
||||
group,
|
||||
num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
|
||||
num_topk=num_topk, use_fp8_dispatch=True,
|
||||
explicitly_destroy=True,
|
||||
allow_multiple_reduction=False,
|
||||
gpu_timeout_secs=10, cpu_timeout_secs=30
|
||||
) if is_legacy_loaded else None
|
||||
|
||||
def run_baseline():
|
||||
recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
|
||||
x, topk_idx=topk_idx, topk_weights=topk_weights,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_baseline,
|
||||
num_experts=num_experts, expert_alignment=alignment,
|
||||
do_cpu_sync=False, do_handle_copy=False,
|
||||
do_expand=True, use_tma_aligned_col_major_sf=True
|
||||
do_expand=True, use_tma_aligned_col_major_sf=True,
|
||||
)
|
||||
n = recv_x[0].size(0)
|
||||
l1_y = torch.empty((n, intermediate_hidden * 2), dtype=torch.bfloat16, device='cuda')
|
||||
@@ -138,26 +184,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
|
||||
l1_y, l2_weights, l2_y, handle.psum_num_recv_tokens_per_expert,
|
||||
use_psum_layout=True, recipe=(1, 1, 32))
|
||||
return ep_buffer.combine(l2_y, handle=handle)[0]
|
||||
|
||||
# Run fused mega MoE
|
||||
# NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
|
||||
def run_fused():
|
||||
buffer.x[:num_tokens].copy_(x[0])
|
||||
buffer.x_sf[:num_tokens].copy_(x[1])
|
||||
buffer.topk_idx[:num_tokens].copy_(topk_idx)
|
||||
buffer.topk_weights[:num_tokens].copy_(topk_weights)
|
||||
|
||||
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
# noinspection PyTypeChecker
|
||||
deep_gemm.fp8_fp4_mega_moe(
|
||||
y,
|
||||
transformed_l1_weights, transformed_l2_weights,
|
||||
buffer,
|
||||
activation_clamp=args.activation_clamp,
|
||||
fast_math=bool(args.fast_math)
|
||||
)
|
||||
return y
|
||||
return ep_buffer.combine(l2_y, handle=handle)[0], cumulative_local_expert_recv_stats_baseline
|
||||
|
||||
# Check correctness (must be bitwise identical)
|
||||
num_correctness_tests = 1 if args.num_correctness_tests is None else args.num_correctness_tests
|
||||
@@ -166,34 +193,36 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
dist_print('Running correctness tests:', once_in_node=True)
|
||||
for i in range(num_correctness_tests):
|
||||
create_inputs()
|
||||
assert torch.equal(run_fused(), run_baseline())
|
||||
for fused_result, baseline_result in zip(run_fused(), run_baseline()):
|
||||
assert torch.equal(fused_result, baseline_result)
|
||||
if (i + 1) % 100 == 0 or i == num_correctness_tests - 1:
|
||||
dist_print(f' > Correctness test #{i + 1}/{args.num_correctness_tests} passed', once_in_node=True)
|
||||
dist_print(f' > Correctness test #{i + 1}/{num_correctness_tests} passed', once_in_node=True)
|
||||
dist_print(once_in_node=True)
|
||||
else:
|
||||
create_inputs()
|
||||
|
||||
# Count local received tokens
|
||||
gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
|
||||
num_recv_tokens = (rank_idx * num_experts_per_rank <= gathered_topk_idx) & \
|
||||
(gathered_topk_idx < (rank_idx + 1) * num_experts_per_rank)
|
||||
num_recv_tokens = num_recv_tokens.sum().item()
|
||||
gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | \
|
||||
(gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1
|
||||
num_recv_tokens = (gathered_topk_idx != -1).sum().item()
|
||||
|
||||
# Benchmark
|
||||
t_fused = bench_kineto(
|
||||
run_fused, 'mega_moe',
|
||||
barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer else dist.barrier(),
|
||||
trace_path=None if not args.dump_profile_traces else f'{args.dump_profile_traces}/mega_moe_rank{rank_idx}.json')
|
||||
t_baseline = do_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0
|
||||
t_baseline = tilelang_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0
|
||||
|
||||
# TFLOPS: 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K
|
||||
safe_div = lambda a, b: float('nan') if b == 0 else a / b
|
||||
tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused)
|
||||
|
||||
# HBM bytes: weights (FP4 packed = 0.5 bytes) + activations (FP8 = 1 byte) + output (BF16 = 2 bytes)
|
||||
num_touched_experts = torch.unique(gathered_topk_idx.flatten()).numel() - 1 # NOTES minus 1 to exclude "-1"
|
||||
num_hbm_bytes = (
|
||||
num_experts_per_rank * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4)
|
||||
num_experts_per_rank * hidden * intermediate_hidden // 2 + # L2 weights (FP4)
|
||||
num_touched_experts * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4)
|
||||
num_touched_experts * hidden * intermediate_hidden // 2 + # L2 weights (FP4)
|
||||
num_recv_tokens * hidden + # L1 acts read (FP8)
|
||||
num_recv_tokens * intermediate_hidden + # L1 output write (FP8)
|
||||
num_recv_tokens * intermediate_hidden + # L2 acts read (FP8)
|
||||
@@ -230,7 +259,9 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Test PyTorch symmetric memory')
|
||||
|
||||
# Resource settings
|
||||
parser.add_argument('--ncu-profile-only', action='store_true', help='Only run profiling without correctness test')
|
||||
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
|
||||
|
||||
# Model settings
|
||||
|
||||
Reference in New Issue
Block a user