diff --git a/README.md b/README.md index 4a4e930e..8ed2ddb8 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Replaces the broken `fp8_nvfp4_mega_moe` kernel from DeepGEMM with a working CUT ## Architecture -DeepSeek-V4-Pro is a 256-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 32 experts. For each token, the router picks the top-6 experts. +DeepSeek-V4-Pro is a 384-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 48 experts. For each token, the router picks the top-6 experts. ### The MoE Forward Pass @@ -45,7 +45,7 @@ Input hidden states (BF16) │ 5. L2 GEMM: down_proj │ ← CUTLASS NVFP4 block-scaled │ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX │ → BF16 output (7168-wide) │ -│ 6. Write to output tensor │ +│ 6. Write to output tensor │ ← caller handles cross-rank all-reduce └─────────────────────────────────────────┘ ``` @@ -59,7 +59,7 @@ Input hidden states (BF16) 2. Model construction └─ DeepseekV4ForCausalLM → DeepseekV4MoE → DeepseekV4DecoderLayer Each layer has: attention + MoE block - MoE block has: shared experts + 256 routed experts + MoE block has: shared experts + 384 routed experts 3. Weight loading └─ 95 safetensor shards loaded @@ -137,7 +137,7 @@ nvfp4_megamoe_kernel/ ### GEMM dimensions (DeepSeek-V4-Pro) - **L1 (gate_up_proj):** M×6144×7168 (per expert) - **L2 (down_proj):** M×7168×3072 (per expert) -- 48 experts per rank (256 total / 8 ranks), top-6 routing +- 48 experts per rank (384 total / 8 ranks), top-6 routing --- diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index a5424840..81e0ea75 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -7,7 +7,7 @@ Architecture: - L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales) - SiLU+Mul activation - L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales) -- NVLink cross-rank sync via symm buffer +- NVLink cross-rank sync handled by caller (not this kernel) - Expert parallel: each rank handles NUM_EXPERTS/8 experts The kernel uses native NVFP4 block-scaled MMA via tcgen05.mma @@ -253,7 +253,7 @@ def nvfp4_mega_moe_full( 3. SiLU + Mul (activation) 4. Quantize L1 output → FP4 + UE4M3 scales 5. L2 GEMM: down_proj (native NVFP4 block-scaled MMA) - 6. NVLink sync + reduce across ranks → write to y + 6. Write to y (caller handles cross-rank all-reduce) Uses tcgen05.mma.kind::mxf8f6f4.block_scale for native E2M1×E2M1 with UE4M3 block-16 scaling in Blackwell tensor cores. @@ -305,5 +305,5 @@ def nvfp4_mega_moe_full( topk_ids, topk_weights, num_experts_per_rank, ) - # Step 6: Write to output + # Step 6: Write to output (caller handles cross-rank all-reduce) y.copy_(l2_output)