diff --git a/csrc/apis/mega_nvfp4.hpp b/csrc/apis/mega_nvfp4.hpp new file mode 100644 index 0000000..d47bfb0 --- /dev/null +++ b/csrc/apis/mega_nvfp4.hpp @@ -0,0 +1,237 @@ +#pragma once + +#include +#include + +#if DG_TENSORMAP_COMPATIBLE +#include "../jit/compiler.hpp" +#endif +#include "../jit/device_runtime.hpp" +#include "../jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp" + +namespace deep_gemm::mega::nvfp4 { + +static int get_token_alignment_for_mega_moe() { + return layout::kLCMCandidateBlockM; +} + +static std::tuple(const torch::Tensor&)>> +get_symm_buffer_size_for_nvfp4_mega_moe( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const bool& use_fp8_dispatch, const std::string& activation) { + DG_HOST_ASSERT(num_experts % num_ranks == 0); + + // Workspace bytes + const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk); + + // NVFP4 layouts: group_size=16, so SF stride is K/16 (twice as many as MXFP4) + const auto fp8_token_layout = layout::Data(hidden); + const auto bf16_token_layout = layout::Data(hidden * 2); + const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); + const auto nvfp4_sf_layout = layout::Data(hidden / 16); + const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16); + const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); + const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false); + const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Input buffers + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, num_max_tokens_per_rank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + nvfp4_sf_layout, 1, num_max_tokens_per_rank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, num_max_tokens_per_rank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, num_max_tokens_per_rank, + input_topk_idx_buffer.get_end_ptr()); + + // Buffer configs + const auto num_max_pool_tokens = static_cast(workspace.num_max_pool_tokens); + int num_max_padded_sf_pool_tokens = 0; + for (int block_m: layout::kCandidateBlockM) { + num_max_padded_sf_pool_tokens = std::max( + num_max_padded_sf_pool_tokens, + layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m) + ); + } + + // L1 input buffer + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, num_max_pool_tokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + nvfp4_sf_layout, 1, num_max_padded_sf_pool_tokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, num_max_pool_tokens, + l1_sf_buffer.get_end_ptr()); + + // L2 input buffer + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, num_max_pool_tokens, + l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer( + nvfp4_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens, + l2_token_buffer.get_end_ptr()); + + // Combine input buffer: BF16 tokens for cross-rank combine + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, num_topk, num_max_tokens_per_rank, + l2_sf_buffer.get_end_ptr()); + + // Check SF buffer requirements + // NVFP4: hidden must be divisible by 64 (4 UE4M3 scales per int32, group_size=16) + DG_HOST_ASSERT(hidden % 64 == 0 and intermediate_hidden % 64 == 0); + DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); + + // Slice function + auto slice_input_buffers = [=](const torch::Tensor& buffer) { + auto x = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), + {num_max_tokens_per_rank, hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + // NVFP4 SF: K/16 bytes per token, packed as K/64 int32 + auto x_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), + {num_max_tokens_per_rank, hidden / 64}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + auto topk_idx = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)), + {num_max_tokens_per_rank, num_topk}, + torch::TensorOptions().dtype(torch::kInt64).device(buffer.device())); + auto topk_weights = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_weights_buffer.base)), + {num_max_tokens_per_rank, num_topk}, + torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); + auto l1_acts = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), + {num_max_pool_tokens, hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + // NVFP4 L1 SF: M-major, K/64 int32 + auto l1_acts_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), + {num_max_padded_sf_pool_tokens, hidden / 64}, + {1, num_max_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + auto l2_acts = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), + {num_max_pool_tokens, intermediate_hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + // NVFP4 L2 SF: M-major, K/64 int32 + auto l2_acts_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), + {num_max_padded_sf_pool_tokens, intermediate_hidden / 64}, + {1, num_max_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); + }; + return {reinterpret_cast(combine_token_buffer.get_end_ptr()), slice_input_buffers}; +} + +static void fp8_nvfp4_mega_moe( + const torch::Tensor& y, + const std::tuple& l1_weights_tuple, + const std::tuple& l2_weights_tuple, + const std::optional& cumulative_local_expert_recv_stats, + const torch::Tensor& sym_buffer, + const std::vector& sym_buffer_ptrs, const int& rank_idx, + const int& num_max_tokens_per_rank, + const int& num_experts, const int& num_topk, + const std::tuple& recipe, + const std::string& activation, + const std::optional& activation_clamp_opt, + const bool& fast_math +) { + const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; + const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; + + // Config checks + const auto num_tokens = static_cast(y.size(0)); + const auto [rm, rn, rk] = recipe; + DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 16); // NVFP4: group_size=16 + DG_HOST_ASSERT(activation == "swiglu"); + + // Activation checks + const auto activation_clamp = + activation_clamp_opt.value_or(std::numeric_limits::infinity()); + DG_HOST_ASSERT(activation_clamp >= 0); + + // Tensor checks + DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K); + const auto arch_major = device_runtime->get_arch_major(); + const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = + check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major); + const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = + check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major); + DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_); + DG_HOST_ASSERT(hidden == hidden_); + DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); + DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); + + // Check weight SF layout for UE4M3 packing, MN-major, and TMA alignment + // NVFP4: kGranK=16, SF packed as int32 (4 UE4M3 bytes per int32) + constexpr int kGranMN = 1, kGranK = 16; + check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, + num_experts_per_rank, true, false, torch::kInt); + check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, + num_experts_per_rank, true, false, torch::kInt); + + // Check stats counter + if (cumulative_local_expert_recv_stats.has_value()) { + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); + } + + // Check buffer bytes + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts_ = num_experts_per_rank * num_ranks; + const auto [num_required_bytes, slice] = get_symm_buffer_size_for_nvfp4_mega_moe( + num_ranks, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + true, activation); + DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes)); + DG_HOST_ASSERT(num_experts == num_experts_); + + // Already registered tensors + const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer); + + // Dispatch into different architectures + if (arch_major == 10) { + sm100_fp8_nvfp4_mega_moe(y, + l1_acts, l1_acts_sf, + l2_acts, l2_acts_sf, + l1_weights, l2_weights, + l1_weights_sf, l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer_ptrs, + rank_idx, num_max_tokens_per_rank, + num_experts_per_rank, + num_tokens, num_topk, + hidden, intermediate_hidden, + activation_clamp, fast_math); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Zero the entire symmetric buffer for debug mode + if (get_env("DG_COMM_KERNEL_DEBUG")) + sym_buffer.zero_(); +} + +static void register_apis(pybind11::module_& m) { +#if DG_TENSORMAP_COMPATIBLE + m.def("get_token_alignment_for_nvfp4_mega_moe", &get_token_alignment_for_mega_moe); + m.def("get_symm_buffer_size_for_nvfp4_mega_moe", &get_symm_buffer_size_for_nvfp4_mega_moe); + m.def("fp8_nvfp4_mega_moe", &fp8_nvfp4_mega_moe); +#endif +} + +} // namespace deep_gemm::mega::nvfp4 diff --git a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp new file mode 100644 index 0000000..cf91d35 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp @@ -0,0 +1,220 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "runtime_utils.hpp" + +#include +#include + +#include "../heuristics/mega_moe.hpp" + +namespace deep_gemm { + +class SM100FP8NVFP4MegaMoERuntime final : public LaunchRuntime { +public: + struct Args { + // Templated arguments + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + float activation_clamp; + bool fast_math; + MegaMoEConfig config; + + // Runtime arguments + void* y; + int* cumulative_local_expert_recv_stats; + int num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + + // Tensormap + CUtensorMap tensor_map_l1_acts; + CUtensorMap tensor_map_l1_acts_sf; + CUtensorMap tensor_map_l1_weights; + CUtensorMap tensor_map_l1_weights_sf; + CUtensorMap tensor_map_l1_output; + CUtensorMap tensor_map_l2_acts; + CUtensorMap tensor_map_l2_acts_sf; + CUtensorMap tensor_map_l2_weights; + CUtensorMap tensor_map_l2_weights_sf; + + // Launch configs + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_nvfp4_mega_moe_impl< + {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, + {}, + {}, + {}, {}, {}, + {}, {}, + {}, + {} + >); +}}; +)", args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.num_experts_per_wave, + args.config.block_m, args.config.block_n, args.config.block_k, + args.config.store_block_m, + args.config.sf_block_m, args.config.sf_block_n, + args.config.num_max_pool_tokens, + args.config.num_padded_sf_pool_tokens, + args.config.num_stages, + args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads, + args.launch_args.grid_dim.first, args.num_ranks, + to_string(args.activation_clamp), + args.fast_math ? "true" : "false"); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.y, + args.cumulative_local_expert_recv_stats, + args.num_tokens, + args.sym_buffer_ptrs, + args.tensor_map_l1_acts, + args.tensor_map_l1_acts_sf, + args.tensor_map_l1_weights, + args.tensor_map_l1_weights_sf, + args.tensor_map_l1_output, + args.tensor_map_l2_acts, + args.tensor_map_l2_acts_sf, + args.tensor_map_l2_weights, + args.tensor_map_l2_weights_sf + )); + } +}; + +static void sm100_fp8_nvfp4_mega_moe( + const torch::Tensor& y, + const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf, + const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf, + const torch::Tensor& l1_weights, const torch::Tensor& l2_weights, + const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf, + const std::optional cumulative_local_expert_recv_stats, + const std::vector& sym_buffer_ptrs, + const int& rank_idx, const int& num_max_tokens_per_rank, + const int& num_experts_per_rank, + const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const float& activation_clamp, + const bool& fast_math +) { + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts = num_experts_per_rank * num_ranks; + const auto num_padded_sf_pool_tokens = static_cast(l1_acts_sf.size(0)); + + // Heuristics + const auto config = get_mega_moe_config( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens); + + // NVFP4: kGranK=16 for group_size=16 + constexpr int kGranK = 16; + + // Make tensormap — weight/activation TMA descriptors are the same as MXFP4 + // (E2M1 packed uint8 is the same format regardless of scale type) + const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, + hidden, config.num_max_pool_tokens, + config.block_k, config.load_block_m, + static_cast(l1_acts.stride(-2)), + config.swizzle_acts_mode); + // NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32 + const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, + config.num_padded_sf_pool_tokens, hidden, + config.sf_block_m, kGranK, + 1, 0); + const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights, + hidden, num_experts_per_rank * intermediate_hidden * 2, + config.block_k, config.load_block_n, + static_cast(l1_weights.stride(-2)), + config.swizzle_weights_mode); + const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf, + intermediate_hidden * 2, hidden, + config.block_n, kGranK, + num_experts_per_rank, 0); + const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_n / 2, config.store_block_m, + static_cast(l2_acts.stride(-2)), + config.swizzle_acts_mode / 2); + const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_k, config.load_block_m, + static_cast(l2_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, + config.num_padded_sf_pool_tokens, intermediate_hidden, + config.sf_block_m, kGranK, + 1, 0); + const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights, + intermediate_hidden, num_experts_per_rank * hidden, + config.block_k, config.load_block_n, + static_cast(l2_weights.stride(-2)), + config.swizzle_weights_mode); + const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf, + hidden, intermediate_hidden, + config.block_n, kGranK, + num_experts_per_rank, 0); + + // Stats can be optional + int* cumulative_local_expert_recv_stats_ptr = nullptr; + if (cumulative_local_expert_recv_stats.has_value()) + cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr(); + + // Launch + const auto num_sms = device_runtime->get_num_sms(); + const SM100FP8NVFP4MegaMoERuntime::Args args = { + .num_max_tokens_per_rank = num_max_tokens_per_rank, + .hidden = hidden, .intermediate_hidden = intermediate_hidden, + .num_experts = num_experts, .num_topk = num_topk, + .num_ranks = num_ranks, + .activation_clamp = activation_clamp, + .fast_math = fast_math, + .config = config, + .y = y.data_ptr(), + .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, + .num_tokens = num_tokens, + .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx), + .tensor_map_l1_acts = tensor_map_l1_acts, + .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf, + .tensor_map_l1_weights = tensor_map_l1_weights, + .tensor_map_l1_weights_sf = tensor_map_l1_weights_sf, + .tensor_map_l1_output = tensor_map_l1_output, + .tensor_map_l2_acts = tensor_map_l2_acts, + .tensor_map_l2_acts_sf = tensor_map_l2_acts_sf, + .tensor_map_l2_weights = tensor_map_l2_weights, + .tensor_map_l2_weights_sf = tensor_map_l2_weights_sf, + .launch_args = LaunchArgs(num_sms, + config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads, + config.smem_size, 2) + }; + + const auto code = SM100FP8NVFP4MegaMoERuntime::generate(args); + const auto runtime = compiler->build("sm100_fp8_nvfp4_mega_moe", code); + SM100FP8NVFP4MegaMoERuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp index a966afe..b929e93 100644 --- a/csrc/python_api.cpp +++ b/csrc/python_api.cpp @@ -7,6 +7,7 @@ #include "apis/gemm.hpp" #include "apis/layout.hpp" #include "apis/mega.hpp" +#include "apis/mega_nvfp4.hpp" #include "apis/runtime.hpp" #ifndef TORCH_EXTENSION_NAME @@ -24,5 +25,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { deep_gemm::gemm::register_apis(m); deep_gemm::layout::register_apis(m); deep_gemm::mega::register_apis(m); + deep_gemm::mega::nvfp4::register_apis(m); deep_gemm::runtime::register_apis(m); } diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh new file mode 100644 index 0000000..a2170e5 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -0,0 +1,1398 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t STORE_BLOCK_M, + uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_nvfp4_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator2Sm; + + // Template checks + DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue and combine threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + + // Thread indices + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf); + } + + // Workspaces + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + // Token and buffer layouts + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + // NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4) + constexpr auto fp8_sf_layout = layout::Data(kHidden / 16); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered inputs + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + + // SF and its buffer configs + // NVFP4: group_size=16 → kGranK=16 (vs MXFP4's 32) + constexpr uint32_t kGranK = 16; + // For NVFP4 scale_vec::4X, UTCCP alignment is still 128 elements + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); + DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "No padding is needed for SFB"); + + // NVFP4 UTCCP 4x32 transpose index mapping within each 128-element group + // Same transpose as MXFP4 — the UTCCP layout is determined by the 128-element + // alignment, not the scale vector size + const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) { + const uint32_t idx = token_idx_in_expert % BLOCK_M; + return token_idx_in_expert / BLOCK_M * SF_BLOCK_M + + (idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u); + }; + + // L1 inputs + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + + // L2 inputs + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr() + ); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr() + ); + + // Combine inputs + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr() + ); + + // Data types + // NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1) + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + // MMA configs + // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; + constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + + // Swizzle configs + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t kSwizzleCDMode = 128; + DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); + + // Epilogue configs + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Shared memory + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + // Shared memory sizes + // NOTES: FP8 CD output for L1 (2 TMA stages, BLOCK_N/2 post-SwiGLU), BF16 output for L2 (no TMA, a single stage) + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages; + constexpr uint32_t SMEM_CD_L2_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; + constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages; + constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = + SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and + SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and + SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0, + "Shared memory of CD/A/B must be aligned to 1024 bytes"); + + // Tensor memory size + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + // NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row + // For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns + // SFB uses BLOCK_N/32 rows × 4 cols + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 4; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 4; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Assign shared memory for dispatch warps + const auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + // GEMM shared memory: C/D, A, B + // NOTES: GEMM shared memory starts after the dispatch region, aligned to 1024 bytes + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + ); + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE); + }); + auto smem_cd_l2 = smem_cd[0]; + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SF shared memory: SFA and SFB per pipeline stage + auto sf_start_ptr = math::advance_ptr(smem_gemm_base, + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Epilogue amax reduction shared memory + auto smem_amax_reduction = reinterpret_cast(smem_sfb[kNumStages]); + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages + i); }); + auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + i); }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2); + + // A cluster sync is essential for 2CTA tensor memory allocation + comm::cluster_sync_with_relaxed_arrive(); + + // Initialization + if (warp_idx == 0) { + // Clean shared memory + if (cute::elect_one_sync()) + ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t)); + } else if (warp_idx == 1) { + // Init m-barriers for dispatch + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Init GEMM barriers + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(2 * 2); + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 3) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + // NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`, + // and `barrier.cluster.wait.aligned` is by default `.acquire` + comm::cluster_sync_with_relaxed_arrive(); + + // Task scheduler + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, + kNumExpertsPerWave, + kNumSMs, kNumRanks>(workspace); + + // MMA pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Intra-SM Barrier indices + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + // NVLink barrier tags + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; + + // Adjust registers + constexpr uint32_t kNumDispatchRegisters = 48; + constexpr uint32_t kNumNonEpilogueRegisters = 40; + constexpr uint32_t kNumEpilogueRegisters = 208; + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + // Grid sync index assignments (dispatch and epilogue use separate counters to avoid conflicts) + constexpr uint32_t kDispatchGridSyncIndex = 0; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + // Different warp roles + if (warp_idx < kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // Dispatch warps + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + // TODO: figure out better unrolling + // Now, `unroll` is better than `unroll 8` + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + // Allocate slots for each token-topk + int expert_idx = -1; + if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } + }; + + // Count experts' tokens + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + atomicAdd_block(smem_expert_count + expert_idx, 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Get SM offset (~6.5 us) + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Write source indices (~2 us with 512 tokens) + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; + const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); + const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); + *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx; + }); + + // Grid sync + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } + ); + + // Write expert count + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + *sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Barrier before pulling + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* After the grid sync above, there is no more writes by other SMs (except 0) */ false, + /* After the NVLink barrier, there is a grid sync */ true + ); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Pull token data and SF from remote ranks into local L1 buffer + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + // Cache expert token counts in registers (same pattern as scheduler) + scheduler.fetch_expert_recv_count(); + + // Per-rank counts for current expert (re-loaded when expert changes) + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + // Advance expert until within the range + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= kNumExpertsPerRank) + break; + + // Update pool block offset for the new expert + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + + // Move start and end to the next expert + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + + // Finish all tokens + if (current_expert_idx >= kNumExpertsPerRank) + break; + + // Load per-rank counts when expert changes + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + // TODO: this is not coalesced + stored_rank_count[i] = j < kNumRanks ? + static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection via iterative min-peeling + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + // Compute active count and min across all ranks + // NOTES: reduce within each lane first, then warp-reduce once + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + // Hit in the current round + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + + // Move into the next round + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + // Read source token-topk index (written by remote dispatch via NVLink) + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA load token from remote rank into shared memory + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Load and store SF (overlaps with TMA token load) + // NVFP4: group_size=16, 4 UE4M3 scales per uint32 + constexpr uint32_t kNumSFUint32 = kHidden / 64; + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 64 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const auto sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M + + transform_sf_token_idx(token_idx_in_expert); + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFUint32, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + // Store weights and token data + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + if (cute::elect_one_sync()) { + // Load weights + const auto weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, + current_rank_in_expert_idx); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + + // Wait for TMA token load to complete + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + + // Store token to local L1 buffer via TMA + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + + // Write source metadata for combine write-back + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; + + // Wait for token TMA store to complete + cute::tma_store_arrive(); + ptx::tma_store_wait<0>(); + ptx::red_add_rel( + workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); + } + __syncwarp(); + } + + // Clean workspace for the next usage, and also do cumulative stats + // NOTES: it is overlapped with combine reduction epilogue + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + // SM 0: clear expert send count + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + // Other SMs: clean blocks + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + // Read expert token count before clearing + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); + + // Compute expert pool block offset + expert_pool_block_offset = scheduler.get_pool_block_offset(i); + + // Wait read count ready + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Clean expert token count, and add cumulative results + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + // Clean per-rank token count + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + // Clean L1 and L2 arrival stuffs + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + + // Wait for all ranks to finish cleaning + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* Before the NVLink barrier, there is a grid sync */ true, + /* At the end of kernel does not need to sync */ false + ); + } else if (warp_idx == kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for tokens with SFA + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranK * 4u); + + // Compute pool block offset for this expert + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + + // Wait the entire token arrival for linear 1 + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = scheduler.template get_valid_m(); + while (ptx::ld_acq(ptr) != expected); + } else { + // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival + // NOTES: Originally we wait blocks on-demand to overlap L1 calculation + // with L2, but this optimization is negative when `num_experts_per_wave` + // guarantees L1's completion when L2 starts. So we remove it. + // In the future, if `num_experts_per_wave` is not large enough + // due to small `num_experts_per_rank`, we may need to add it back or add a switch + DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes"); + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts + // to avoid undefined behavior when `num_k_blocks == 32` + const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1; + while (ptx::ld_acq_gpu(ptr) != expected); + } + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute token offset from pool block index + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block_idx; + + // Add 2 CTA offsets for non-leader CTA + if (not is_leader_cta) + m_idx += scheduler.template get_valid_m() / 2; + + // TMA copy tokens and SFA, then arrive at full barrier + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 1) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for weights with SF + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + const auto tensor_map_sfb_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights_sf : &tensor_map_l1_weights_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; + const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute weight offset + uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx; + + // TMA copy weights with SF + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tma::copy( + tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 2) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM MMA issue warp (only the leader CTA will run) + if (is_leader_cta) { + // NVFP4: use float_ue4m3_t scale factor type with mxf4nvf4 instruction + // NOTES: always swap A/B + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + b_dtype_t, a_dtype_t, float, cutlass::float_ue4m3_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K + >(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Dynamic update of UMMA N based on effective M + mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + + // Wait tensor memory empty barrier arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + constexpr uint16_t kCTAMask = (1 << 2) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Launch MMAs + #pragma unroll 2 + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA load completion + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // UTCCP copy SFA and SFB to TMEM + // NVFP4: scale_vec::4X, each 128-element block → 8 TMEM cols + using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + // NVFP4 4X: 8 TMEM columns per 128-element SF group + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 8); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 8); + } + + // Issue UMMA + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const auto runtime_instr_desc = + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + a_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); + // NVFP4: use mxf4nvf4 instruction with UE4M3 scales + ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma( + b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_k_blocks - 1); + } + }); + + // To safely deconstruct barriers, we need another round of waits + if (current_iter_idx > 0) { + const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1; + tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } else if (warp_idx == kNumDispatchWarps + 3) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // GEMM epilogue warps + const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and + kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps"); + + // TODO: support effective block M + // NOTES: + // - 2 warpgroups divide the whole BM into BM / 2 + // - 4 warps divide the whole BN into BN / 4 + // - BM / 2 is further divided into stored blocks, i.e. with `STORE_BLOCK_M` size + // - `STORE_BLOCK_M` in further divided into `ATOM_M` + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t ATOM_M = 8; + constexpr uint32_t kNumBankGroupBytes = 16u; + constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M; + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M"); + DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M"); + DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M"); + DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N"); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Wait UMMA arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_phase); + ptx::tcgen05_after_thread_sync(); + + // Compute offsets + // NOTES: use shuffle here to let NVCC know warp divergence won't happen + const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m(), 0); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t n_idx = n_block_idx * BLOCK_N; + + if (block_phase == sched::BlockPhase::Linear1) { + // Unified L1 epilogue: SwiGLU in-place using granularity 8 interleaved weights + // With `SM100_TMEM_LOAD_16dp256b1x`, gate/up pairs are: + // (values[0], values[2]), (values[1], values[3]), + // (values[4], values[6]), (values[5], values[7]) + float stored_cached_weight = 0; + + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + // Iterate all atoms in the store block + float2 swiglu_values[kNumAtomsPerStore * 2]; + float2 amax_values[kNumAtomsPerStore]; + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + const uint32_t j = s * kNumAtomsPerStore + i; + + // Load weights from global into register cache per 32 tokens + DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size"); + if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) { + stored_cached_weight = *l1_topk_weights_buffer + .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx) + .get_base_ptr(); + } + + // Load weights from register cache + const float2 weights = { + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0), + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1) + }; + + // Load from TMEM + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Signal tensor memory consumed on the last atom + if (j == WG_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Apply SwiGLU: silu(gate) * up + // Gate/up pairs: (0, 2), (1, 3), (4, 6), (5, 7) + auto fp32_values = reinterpret_cast(values); + #pragma unroll + for (uint32_t k = 0; k < 2; ++ k) { + auto bf16_gate = __float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1])); + auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3])); + + // Clamp + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp}); + bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp}); + bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp}); + } + + // SwiGLU + auto gate = __bfloat1622float2(bf16_gate); + auto neg_gate_exp = make_float2( + kFastMath ? __expf(-gate.x) : expf(-gate.x), + kFastMath ? __expf(-gate.y) : expf(-gate.y)); + const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp); + if constexpr (kFastMath) { + gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)}); + } else { + gate = {gate.x / denom.x, gate.y / denom.y}; + } + const auto up = __bfloat1622float2(bf16_up); + swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights); + } + + // Amax reduction + amax_values[i].x = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)), + math::ReduceMax()); + amax_values[i].y = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)), + math::ReduceMax()); + if (lane_idx < 4) + smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i]; + __syncwarp(); + } + + // Wait shared memory release from previous TMA store + // And fence `smem_amax_reduction` + const uint32_t tma_stage_idx = s % kNumTMAStoreStages; + ptx::tma_store_wait(); + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Cast to FP8 E4M3 and store into shared memory + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + // Reduce amax + const float2 wp_amax = + smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4]; + amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); + amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); + + // Calculate SF + float2 sf, sf_inv; + math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv); + + // Cast + const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y)); + + // STSM + uint32_t row = lane_idx; + uint32_t col = warp_idx_in_wg; + const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N + + i * ATOM_M * L1_OUT_BLOCK_N + + row * L1_OUT_BLOCK_N + + (col ^ (row / 2)) * kNumBankGroupBytes; + ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr); + + // Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout) + // Only one warp per pair writes (both hold the same SF after cross-warp reduce) + // Each lane < 4 holds SF for 2 rows (sf.x and sf.y) + if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { + const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; + const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; + const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); + const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4 + // NOTES: originally there was: + // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2 + // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)` + // We find out that + // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside + // 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside + // This reduce the number of computation instructions. + const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + __builtin_assume(token_base_idx < BLOCK_M); + const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + // NVFP4: convert float scale to UE4M3 format + // UE4M3: sign=0 + 4 exp + 3 mantissa, max=448 + auto to_ue4m3 = [](float v) -> uint8_t { + v = fmaxf(0.0f, fminf(v, 448.0f)); + cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v); + return reinterpret_cast(e4m3_val) & 0x7F; + }; + sf_base_ptr[sf_addr] = to_ue4m3(sf.x); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = to_ue4m3(sf.y); + } + __syncwarp(); + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Issue TMA store after all atoms in this store block + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N, + out_n_idx, + m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + } + + // Notify L2 + // TODO: less epilogue sync scope + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large"); + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx + ); + } + __syncwarp(); + } else { + DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M"); + constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8; + + // L2 BF16 epilogue: write GEMM output to remote combine buffer via NVLink + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + // TODO: check performance + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++ i) { + // Load from TMEM using .16x256b shape to satisfy STSM layout requirements + // Start from lane index 0 and 16 + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Wait shared memory release from previous NVLink store + // NOTES: skip for the first store block since the prior full barrier already ensures completion + if (i == 0 and s > 0) + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Signal tensor memory consumed + if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Store into shared memory + // NOTES: only use first 16 lanes for address + // NOTES: 2 warps share a BF16 swizzle atom + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode + + i * ATOM_M * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + ptx::SM90_U32x4_STSM_T::copy( + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr + ); + } + + // Wait shared memory ready + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Write into remote buffers + // One warp per row, now the layout is different from shared memory storing + const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M; + const uint32_t bank_group_idx = lane_idx % 8; + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16; + const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store; + + // Skip padding rows beyond the actual token count for this expert + if (m_idx_in_block >= valid_m) + break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + // Read from shared memory + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode + + row_in_store * kSwizzleCDMode + + (bank_group_idx ^ row_in_atom) * kNumBankGroupBytes; + const auto packed = ptx::ld_shared(reinterpret_cast(smem_ptr)); + + // Write into remote + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + const auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * static_cast(sizeof(nv_bfloat16)) + (lane_idx % 16) * static_cast(sizeof(float4))); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + // Ensure the next epilogue safe to use shared memory + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + }); + + // Deallocate tensor memory + // NOTES: must be called by the same logical warp ID on both CTAs + if (epilogue_warp_idx == 0) + Allocator().free(0, kNumTmemCols); + + // NVLink barrier (grid sync + cross-rank signal + grid sync): ~4 us + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + + // Barrier with dispatch warps, so that they can do clean workspace + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Combine: reduce top-k results and write back + // NOTES: reuse shared memory from start up to the barriers + // 1 token, 1 topk latency: ~3 us + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + // 3 slots of chunk is needed: 2 load stages and 1 store + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + + // NOTES: either 1 or 2 chunks for simplicity + // NOTES: Restrict on both smem and register + constexpr uint32_t kNumChunks = + kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements (one per lane)"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + // Verify combined shared memory budget at runtime + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + // Per-warp buffer: 2 stage load buffers + 1 store buffer + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr(smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + // Per-warp barriers + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + // Iterate over all tokens + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + // Read top-k slot indices: each lane reads one slot, then broadcast via exchange + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + const int stored_topk_slot_idx = lane_idx < kNumTopk ? + static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; + const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); + + // Iterate all chunks + for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { + const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; + + // Move mask and load + uint32_t mask = total_mask; + const auto move_mask_and_load = [&](const uint32_t& i) { + if (mask) { + // Move + const uint32_t slot_idx = __ffs(mask) - 1; + mask ^= 1 << slot_idx; + + // Load + if (cute::elect_one_sync()) { + const auto src_ptr = math::advance_ptr( + combine_token_buffer.get_rank_buffer(slot_idx) + .get_data_buffer(token_idx).get_base_ptr(), + chunk_byte_offset); + ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); + ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); + } + __syncwarp(); + return true; + } + return false; + }; + + // Load the first selection + bool do_reduce = move_mask_and_load(load_stage_idx); + + // Accumulate all top-k contributions for this chunk in float registers + float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; + while (do_reduce) { + // Prefetch next top-k into the buffer while current is being accumulated + do_reduce = move_mask_and_load(load_stage_idx ^ 1); + + // Accumulate + combine_load_barriers[load_stage_idx]->wait(combine_phase); + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; + const auto bf16_values = reinterpret_cast(&uint4_values); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); + } + combine_phase ^= load_stage_idx; + load_stage_idx ^= 1; + } + + // Cast + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + uint4 casted; + auto casted_bf16 = reinterpret_cast(&casted); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); + + // Wait share memory release and write + if (j == 0) { + ptx::tma_store_wait<0>(); + __syncwarp(); + } + ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, + casted.x, casted.y, casted.z, casted.w); + } + __syncwarp(); + + // TMA store the token chunk + if (cute::elect_one_sync()) { + cute::tma_store_fence(); + ptx::tma_store_1d( + math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), + combine_store_buffer, kNumChunkBytes); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh b/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh index 528b3dd..f4ed99f 100644 --- a/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh +++ b/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh @@ -139,6 +139,50 @@ struct SM100_MMA_MXF4_SS { } }; +/// NVFP4 block-scaled MMA (scale_vec::4X = block16, UE4M3 scales) +struct SM100_MMA_MXF4NVF4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +/// NVFP4 block-scaled MMA (single CTA, scale_vec::4X = block16, UE4M3 scales) +struct SM100_MMA_MXF4NVF4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + struct SM100_MMA_F16BF16_WS_SS { CUTLASS_DEVICE static void fma(uint64_t const& desc_a, diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index e624ecf..52fbb35 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -105,6 +105,57 @@ def transform_weights_for_mega_moe( return l1_weights, l2_weights +def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: + """Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout. + + NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X). + The UTCCP layout packs 4 consecutive scale bytes into each int32, + then applies the 4x32 transpose for TMA consumption. + + Input: (num_experts, mn, K//16) float8_e4m3fn scales + Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales + """ + num_groups, mn, sf_k = sf.shape + assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}" + assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}" + + # View as uint8 and pack 4 consecutive bytes into int32 + sf_uint8 = sf.view(torch.uint8) # (num_groups, mn, sf_k) + # Pack: every 4 uint8 → 1 int32 + packed = (sf_uint8[..., 0::4].to(torch.int32) | + (sf_uint8[..., 1::4].to(torch.int32) << 8) | + (sf_uint8[..., 2::4].to(torch.int32) << 16) | + (sf_uint8[..., 3::4].to(torch.int32) << 24)) # (num_groups, mn, sf_k//4) + + # Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined + # by the 128-element alignment, not the scale vector size) + packed_sf_k = sf_k // 4 + result = (packed.reshape(num_groups, -1, 4, 32, packed_sf_k) + .transpose(2, 3) + .reshape(num_groups, mn, packed_sf_k)) + return torch.empty_like(packed).copy_(result) + + +def transform_nvfp4_weights_for_mega_moe( + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + """Transform NVFP4 expert weights for the mega_moe kernel. + + NVFP4 weights come as (weight, scale) where: + - weight: uint8 E2M1 packed, shape (num_experts, N, K//2) + - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16) + + The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout. + """ + # L1: interleave gate/up, then pack + transpose SF for UTCCP + l1_interleaved = _interleave_l1_weights(l1_weights) + l1_weights = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) + # L2: only pack + transpose SF for UTCCP + l2_weights = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_weights[1])) + return l1_weights, l2_weights + + def fp8_fp4_mega_moe(y: torch.Tensor, l1_weights: Tuple[torch.Tensor, torch.Tensor], l2_weights: Tuple[torch.Tensor, torch.Tensor], @@ -126,3 +177,32 @@ def fp8_fp4_mega_moe(y: torch.Tensor, activation, activation_clamp, fast_math ) + + +def fp8_nvfp4_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (1, 1, 16), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + """NVFP4 mega MoE: uses kind::mxf4nvf4.block_scale.scale_vec::4X + with UE4M3 block scales (group_size=16). + + Weight format: (uint8 E2M1 packed, int32 packed UTCCP UE4M3 scales) + Recipe: (1, 1, 16) — kGranK=16 for NVFP4 group_size=16. + """ + _C.fp8_nvfp4_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + )