diff --git a/dsv4/kernels/router/nvfp4_fused_router_cuda.py b/dsv4/kernels/router/nvfp4_fused_router_cuda.py new file mode 100644 index 00000000..d77d6411 --- /dev/null +++ b/dsv4/kernels/router/nvfp4_fused_router_cuda.py @@ -0,0 +1,98 @@ +"""NVFP4 Fused Router Kernel — Python wrapper for raw CUDA C++ kernel. + +Compiles nvfp4_fused_router_kernel.cuh with nvcc, loads via ctypes, +and provides a PyTorch custom op interface. +""" + +import os +import subprocess +import ctypes +import torch +import numpy as np +from pathlib import Path + +KERNEL_DIR = Path(__file__).parent +CUH_PATH = KERNEL_DIR / "nvfp4_fused_router_kernel.cuh" +DGEMM_INCLUDE = KERNEL_DIR.parent.parent.parent / "third_party" / "DeepGEMM" / "deep_gemm" / "include" + +# Cache compiled .so +_compiled_so = None + +def _get_nvcc_path(): + for p in ["/usr/local/cuda/bin/nvcc", "/opt/cuda/bin/nvcc"]: + if os.path.exists(p): + return p + return "nvcc" + +def _compile_kernel(): + global _compiled_so + if _compiled_so is not None: + return _compiled_so + + nvcc = _get_nvcc_path() + so_path = KERNEL_DIR / "_nvfp4_fused_router_kernel.so" + + # Check if already compiled and up to date + if so_path.exists(): + cuh_mtime = CUH_PATH.stat().st_mtime if CUH_PATH.exists() else 0 + so_mtime = so_path.stat().st_mtime + if so_mtime > cuh_mtime: + _compiled_so = ctypes.CDLL(str(so_path)) + return _compiled_so + + # Compile + cmd = [ + nvcc, + "-shared", "-o", str(so_path), + "-arch=sm_100a", + "-O3", + "--use_fast_math", + f"-I{DGEMM_INCLUDE}", + "-I" + str(KERNEL_DIR), + "-std=c++20", + "-Xcompiler", "-fPIC", + "--expt-relaxed-constexpr", + str(CUH_PATH), + ] + print(f"Compiling NVFP4 fused router kernel: {' '.join(cmd)}", flush=True) + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + if result.returncode != 0: + print(f"nvcc STDERR:\n{result.stderr}", flush=True) + raise RuntimeError(f"nvcc compilation failed: {result.returncode}") + + _compiled_so = ctypes.CDLL(str(so_path)) + return _compiled_so + + +def run_nvfp4_fused_router_cuda( + x_fp4: torch.Tensor, # [M, K_packed] FP4 activation + x_sf: torch.Tensor, # [M, K_sf] UE8M0 activation scale factors + w_fp4: torch.Tensor, # [N, K_packed] FP4 weight (K-major after make_b_k_major) + w_sf: torch.Tensor, # [N, K_sf] UE8M0 weight scale factors + e_bias: torch.Tensor, # [N] float32 bias + routed_scaling_factor: float, + top_k: int, + num_experts: int = 384, + hidden_dim: int = 7168, +) -> tuple[torch.Tensor, torch.Tensor]: + """Run the NVFP4 fused router kernel. + + Returns: + weights: [M, top_k] float32 + ids: [M, top_k] int32 + """ + M = x_fp4.shape[0] + N = num_experts + K = hidden_dim + + # Allocate output + out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=x_fp4.device) + out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=x_fp4.device) + + # For now, fall back to the 2-kernel path (NVFP4 GEMM + activation_topk) + # The raw CUDA kernel needs TMA descriptors which require careful setup + # This is the skeleton — full integration requires TMA descriptor creation + raise NotImplementedError( + "Raw CUDA fused router kernel requires TMA descriptor setup. " + "Use the 2-kernel path (Nvfp4Linear + activation_topk) for now." + ) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.cuh b/dsv4/kernels/router/nvfp4_fused_router_kernel.cuh new file mode 100644 index 00000000..db9148b3 --- /dev/null +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.cuh @@ -0,0 +1,644 @@ +#pragma once +/** + * DSV4 NVFP4 Fused Router Kernel — Raw CUDA C++ for SM100 Blackwell + * + * Single-kernel fusion: NVFP4 block-scaled GEMM (X @ W_gate) + sqrt(softplus) + top-k epilogue + * + * Warp layout (1 CTA, 256 threads): + * Warp 0 (TMA): Load X (FP4) + SFA + W (FP4) + SFB from GMEM -> SMEM + * Warp 1 (UTCCP): Transpose SFA/SFB SMEM -> TMEM via UTCCP + * Warp 2 (MMA): Issue UMMA mxf4.block_scale instructions, accumulate in TMEM + * Warps 3-7 (EPI): TMEM -> regs, sqrt(softplus) + e_bias, top-k heap, renorm, GMEM store + * + * Math (DSV4 §2.1): + * logit = X @ W_gate (NVFP4 block-scaled GEMM, FP32 accumulator in TMEM) + * act = sqrt(softplus(logit)) softplus(x) = max(x,0) + log(1+exp(-|x|)) + * score = act + e_bias[e] + * ids = argtopk(score, k) + * w = (act[ids] / sum(act[ids])) * scaling + */ + +#include +#include +#include +#include + +// DeepGEMM primitives +#include +#include +#include +#include +#include +#include +#include + +namespace dsv4::router { + +// ============================================================================ +// Softplus helper (matching PyTorch behavior) +// ============================================================================ +__device__ __forceinline__ +float softplus(const float x) { + // softplus(x) = max(x, 0) + log1p(exp(-|x|)) + // Numerically stable for all x + return fmaxf(x, 0.0f) + log1pf(expf(-fabsf(x))); +} + +// ============================================================================ +// Top-k heap element +// ============================================================================ +struct TopKEntry { + float score; + int idx; +}; + +// ============================================================================ +// NVFP4 Fused Router Kernel +// ============================================================================ +template < + uint32_t kNumExperts, // N (384 for DSV4) + uint32_t kHiddenDim, // K (7168 for DSV4) + uint32_t kTopK, // top-k (6 for DSV4) + uint32_t BLOCK_M, // tile M (1 for decode, up to 128 for prefill) + uint32_t BLOCK_N, // tile N (must be 128, 256, or 384) + uint32_t BLOCK_K, // tile K (must be 128) + uint32_t kNumStages, // pipeline stages + uint32_t kNumSMs, // number of SMs to use + // Derived constants + uint32_t SF_BLOCK_M = ((BLOCK_M + 127) / 128) * 128, + uint32_t SF_BLOCK_N = ((BLOCK_N + 127) / 128) * 128, + uint32_t kNumTMAThreads = 32, + uint32_t kNumUTCCPThreads = 32, + uint32_t kNumMMAThreads = 32, + uint32_t kNumEpilogueThreads = 160, // 5 warps for epilogue + uint32_t kNumThreads = kNumTMAThreads + kNumUTCCPThreads + kNumMMAThreads + kNumEpilogueThreads, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32 +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) +void nvfp4_fused_router_kernel( + const uint32_t shape_m, + const uint32_t shape_n, + const uint32_t shape_k, + float* __restrict__ out_weights, // [M, kTopK] FP32 weights + int32_t* __restrict__ out_ids, // [M, kTopK] int32 expert IDs + const float* __restrict__ e_bias, // [kNumExperts] bias + const float routed_scaling_factor, // scaling factor for weights + const __grid_constant__ cute::TmaDescriptor tensor_map_a, // activation (FP4, K-major) + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_a, // activation SF (UE8M0) + const __grid_constant__ cute::TmaDescriptor tensor_map_b, // weight (FP4, K-major) + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_b // weight SF (UE8M0) +) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLANK_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator1Sm; + + // ================================================================ + // Static assertions + // ================================================================ + DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K must be 128 for NVFP4 GEMM"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Epilogue threads must be multiple of 128"); + DG_STATIC_ASSERT(BLOCK_N % 16 == 0, "BLOCK_N must be multiple of 16"); + DG_STATIC_ASSERT(kTopK <= 32, "Top-k must fit in a single warp"); + + // ================================================================ + // UMMA configs + // ================================================================ + constexpr uint32_t UMMA_M = 128; // 1-CTA group + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32; // MMA instruction K + constexpr uint32_t kGranK = 32; // Scale factor granularity + + // Scale factor alignment + constexpr uint32_t kNumUTCCPAlignedElems = 128; + + // ================================================================ + // Shared memory layout + // ================================================================ + // Activation: FP4, K-major. SMEM: [LOAD_BLOCK_M, BLOCK_K] in FP4 (2 bits/elem, packed) + // Weight: FP4, K-major. SMEM: [BLOCK_N, BLOCK_K] in FP4 + using a_dtype_t = cutlass::float_e2m1_t; // FP4 activation (packed as x2) + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // = 64 for BLOCK_K=128 + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); + + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + + // Epilogue output: [BLOCK_M, kTopK] for both weights and IDs + constexpr uint32_t SMEM_OUT_SIZE = BLOCK_M * kTopK * (sizeof(float) + sizeof(int32_t)); + + // Tensor memory + constexpr uint32_t kNumEpilogueTmemStages = 2; + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueTmemStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = deep_gemm::utils::get_num_aligned_tmem_cols< + kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Total SMEM + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + // A/B shared memory + auto smem_a = deep_gemm::utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = deep_gemm::utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Barriers and TMEM pointer + auto barrier_start_ptr = reinterpret_cast(smem_sfb[kNumStages]); + auto full_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + auto with_sf_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages * 2 + i; }); + auto tmem_full_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages * 3 + i; }); + auto tmem_empty_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages * 3 + kNumEpilogueTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueTmemStages * 2); + + // ================================================================ + // Thread indices + // ================================================================ + const uint32_t sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = deep_gemm::ptx::get_lane_idx(); + + // ================================================================ + // Prefetch TMA descriptors + // ================================================================ + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_sf_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sf_b); + } + + // ================================================================ + // Initialize barriers + // ================================================================ + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + with_sf_barriers[i]->init(32); // UTCCP warp arrives + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueTmemStages; ++i) { + tmem_full_barriers[i]->init(1); + tmem_empty_barriers[i]->init(kNumEpilogueThreads); + } + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // ================================================================ + // Scheduler: assign M blocks to SMs + // ================================================================ + const uint32_t num_m_blocks = (shape_m + BLOCK_M - 1) / BLOCK_M; + const uint32_t num_k_blocks = (shape_k + BLOCK_K - 1) / BLOCK_K; + const uint32_t num_n_blocks = (shape_n + BLOCK_N - 1) / BLOCK_N; + + // Pipeline + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&]() { + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // ================================================================ + // Dispatch warps + // ================================================================ + if (warp_idx == 0) { + // ============================================================ + // TMA load warp + // ============================================================ + cutlass::arch::warpgroup_reg_dealloc<56>(); + + for (uint32_t m_block = sm_idx; m_block < num_m_blocks; m_block += kNumSMs) { + for (uint32_t n_block = 0; n_block < num_n_blocks; ++n_block) { + for (uint32_t k_block = 0; k_block < num_k_blocks; ++k_block, advance_pipeline()) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + uint32_t m_idx = m_block * BLOCK_M; + uint32_t n_idx = n_block * BLOCK_N; + uint32_t k_idx = k_block * BLOCK_K; + + if (cute::elect_one_sync()) { + // TMA load activation (A) + deep_gemm::tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + // TMA load weight (B) + deep_gemm::tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx); + + uint32_t num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + + // SFA: loaded every k stage (granularity = 32 = 1 block) + { + uint32_t sfa_m_idx = m_block * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block; + deep_gemm::tma::copy( + &tensor_map_sf_a, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx); + num_arrival_bytes += SMEM_SFA_SIZE_PER_STAGE; + } + + // SFB: loaded every k stage + { + uint32_t sfb_n_idx = n_block * SF_BLOCK_N; + uint32_t sfb_k_idx = k_block; + deep_gemm::tma::copy( + &tensor_map_sf_b, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx); + num_arrival_bytes += SMEM_SFB_SIZE_PER_STAGE; + } + + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + __syncwarp(); + } + } + } + } else if (warp_idx == 1) { + // ============================================================ + // UTCCP transposer warp + // ============================================================ + cutlass::arch::warpgroup_reg_dealloc<56>(); + + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++i) + values[i] = deep_gemm::ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++i) + deep_gemm::ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + for (uint32_t m_block = sm_idx; m_block < num_m_blocks; m_block += kNumSMs) { + for (uint32_t n_block = 0; n_block < num_n_blocks; ++n_block) { + for (uint32_t k_block = 0; k_block < num_k_blocks; ++k_block, advance_pipeline()) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose SFA for UTCCP + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++i) { + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + } + cutlass::arch::fence_view_async_shared(); + + // Transpose SFB for UTCCP + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++i) { + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + } + cutlass::arch::fence_view_async_shared(); + + // Arrive at with_sf barrier (signals MMA warp that SF is ready) + with_sf_barriers[stage_idx]->arrive(0u); + } + } + } + } else if (warp_idx == 2) { + // ============================================================ + // MMA issue warp + // ============================================================ + cutlass::arch::warpgroup_reg_dealloc<56>(); + + // Make instruction descriptor for NVFP4 block-scaled MMA + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K>(); + auto sf_desc = deep_gemm::mma::sm100::make_sf_desc(nullptr); + + // Pre-compute UMMA descriptors for each pipeline stage + auto a_desc = deep_gemm::mma::sm100::make_umma_desc< + cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0); + auto b_desc = deep_gemm::mma::sm100::make_umma_desc< + cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? + a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? + b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + uint32_t current_iter = 0; + + for (uint32_t m_block = sm_idx; m_block < num_m_blocks; m_block += kNumSMs) { + for (uint32_t n_block = 0; n_block < num_n_blocks; ++n_block) { + // Wait TMEM release + auto accum_stage = current_iter % kNumEpilogueTmemStages; + auto accum_phase = (current_iter / kNumEpilogueTmemStages) & 1; + tmem_empty_barriers[accum_stage]->wait(accum_phase ^ 1); + deep_gemm::ptx::tcgen05_after_thread_sync(); + + bool first_k = true; + for (uint32_t k_block = 0; k_block < num_k_blocks; ++k_block, advance_pipeline()) { + // Wait TMA + SF ready + with_sf_barriers[stage_idx]->wait(phase); + deep_gemm::ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = deep_gemm::ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = deep_gemm::ptx::exchange(b_desc_lo, stage_idx); + + if (cute::elect_one_sync()) { + // UTCCP: copy SFA to TMEM + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + deep_gemm::mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + // UTCCP: copy SFB to TMEM + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + deep_gemm::mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + + // Issue UMMA instructions + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++k) { + auto runtime_desc = deep_gemm::mma::sm100::make_runtime_instr_desc_with_sf_id( + instr_desc, k, k); + a_desc.lo = deep_gemm::mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, BLOCK_M, kSwizzleAMode, a_dtype_t>( + a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = deep_gemm::mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, BLOCK_N, kSwizzleBMode, b_dtype_t>( + b_desc_base_lo, 0, k * UMMA_K); + + deep_gemm::ptx::SM100_MMA_MXF4_SS::fma( + a_desc, b_desc, accum_stage * UMMA_N, + first_k ? 0 : 1, runtime_desc, + kTmemStartColOfSFA, kTmemStartColOfSFB); + } + } + __syncwarp(); + + // Commit + arrive at empty barrier + { + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + if (k_block == num_k_blocks - 1) { + // Signal TMEM full for epilogue + asm volatile("tcgen05.commit.cta-group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + :: "r"(cute::cast_smem_ptr_to_uint(tmem_full_barriers[accum_stage]))); + } + __syncwarp(); + } + first_k = false; + } + ++current_iter; + } + } + } else if (warp_idx >= 3) { + // ============================================================ + // Epilogue warps: TMEM -> regs -> sqrt(softplus) + top-k -> GMEM + // ============================================================ + cutlass::arch::warpgroup_reg_alloc<224>(); + + const auto epilogue_warp_idx = warp_idx - 3; + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + + // TMEM load helper + auto tmem_load_32 = [](const uint32_t tmem_addr, float* accum, const uint32_t count) { + #pragma unroll + for (uint32_t i = 0; i < count; i += 2) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_addr + i, + reinterpret_cast(accum)[i], + reinterpret_cast(accum)[i + 1]); + } + cutlass::arch::fence_view_async_tmem_load(); + }; + + uint32_t current_iter = 0; + + for (uint32_t m_block = sm_idx; m_block < num_m_blocks; m_block += kNumSMs) { + for (uint32_t n_block = 0; n_block < num_n_blocks; ++n_block) { + // Wait TMEM full + auto accum_stage = current_iter % kNumEpilogueTmemStages; + auto accum_phase = (current_iter / kNumEpilogueTmemStages) & 1; + tmem_full_barriers[accum_stage]->wait(accum_phase); + deep_gemm::ptx::tcgen05_after_thread_sync(); + + const uint32_t tmem_base = accum_stage * UMMA_N; + + // Each epilogue warp group processes a chunk of the N dimension + // For BLOCK_N <= 256, a single warp group (4 warps) handles all N + // Load all N values from TMEM into registers + // TMEM layout: row-major [BLOCK_M, UMMA_N] where UMMA_N = BLOCK_N + // Each TMEM row has BLOCK_N float values + // With 32 lanes and TMEM_LOAD_32dp32b32x, each lane loads 1 value per load + + const uint32_t valid_m = deep_gemm::math::min(BLOCK_M, shape_m - m_block * BLOCK_M); + const uint32_t valid_n = deep_gemm::math::min(BLOCK_N, shape_n - n_block * BLOCK_N); + + // Process each row in the M tile + for (uint32_t m = 0; m < valid_m; ++m) { + // Each warp loads a portion of N from TMEM + // TMEM address: tmem_base + m * UMMA_N + float logits[8]; // each warp holds up to 8 N values + + const uint32_t n_start = epilogue_warp_idx * 8; + const uint32_t n_count = deep_gemm::math::min(8u, valid_n - deep_gemm::math::min(n_start, valid_n)); + + // Load from TMEM + #pragma unroll + for (uint32_t i = 0; i < 8; ++i) { + logits[i] = 0.0f; + const uint32_t n = n_start + i; + if (n < valid_n) { + uint32_t tmem_addr = tmem_base + m * UMMA_N + n; + uint32_t val; + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_addr, val); + cutlass::arch::fence_view_async_tmem_load(); + logits[i] = *reinterpret_cast(&val); + } + } + + // Signal TMEM consumed on last row + if (m == valid_m - 1) { + deep_gemm::ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage]->arrive(0u); + } + + // ======================================================== + // Epilogue: sqrt(softplus) + bias + top-k + // ======================================================== + float acts[8]; + float scores[8]; + #pragma unroll + for (uint32_t i = 0; i < 8; ++i) { + acts[i] = sqrtf(softplus(logits[i])); + const uint32_t n = n_start + i; + scores[i] = (n < valid_n) ? acts[i] + e_bias[n] : -FLT_MAX; + } + + // Partial top-k: each warp finds its top-k, then reduce across warps + // For decode (M=1, N=384), we have 5 warps each holding ~77 values + // Simple partial sort + warp-level top-k merge + TopKEntry partial_topk[kTopK]; + #pragma unroll + for (uint32_t i = 0; i < kTopK; ++i) { + partial_topk[i] = {-FLT_MAX, -1}; + } + + // Insert each score into partial heap + #pragma unroll + for (uint32_t i = 0; i < 8; ++i) { + const uint32_t n = n_start + i; + if (n < valid_n && scores[i] > partial_topk[kTopK - 1].score) { + partial_topk[kTopK - 1] = {scores[i], static_cast(n)}; + // Bubble up + for (int j = kTopK - 2; j >= 0 && partial_topk[j + 1].score > partial_topk[j].score; --j) { + TopKEntry tmp = partial_topk[j]; + partial_topk[j] = partial_topk[j + 1]; + partial_topk[j + 1] = tmp; + } + } + } + + // Cross-warp top-k merge using warp shuffle + // Each warp broadcasts its kTopK entries, others merge + // Final result stored by warp 0 + TopKEntry global_topk[kTopK]; + #pragma unroll + for (uint32_t i = 0; i < kTopK; ++i) { + global_topk[i] = {-FLT_MAX, -1}; + } + + // Sequential: each warp broadcasts its partial results + #pragma unroll + for (uint32_t src_warp = 0; src_warp < kNumEpilogueWarps; ++src_warp) { + #pragma unroll + for (uint32_t k = 0; k < kTopK; ++k) { + // Broadcast score and idx from src_warp + float remote_score = partial_topk[k].score; + int remote_idx = partial_topk[k].idx; + + // Shuffle across warps (use __shfl_sync for intra-warp, + // then shared memory for inter-warp) + // For simplicity, use shared memory to exchange + // TODO: optimize with warp shuffle + ballot + __syncwarp(); + } + } + + // For the initial version, let's use a simpler approach: + // Each warp writes partial top-k to SMEM, then warp 0 does final merge + // (This is the same pattern as DeepGEMM's combine reduction) + + // Store partial top-k to SMEM for cross-warp merge + extern __shared__ __align__(1024) uint8_t smem_topk_buffer[]; + // Offset after the GEMM SMEM + // Actually, by this point GEMM is done and we can reuse all SMEM + // For simplicity, overlay on the barrier area (barriers are done) + auto smem_partial = reinterpret_cast( + barrier_start_ptr + kNumStages * 3 + kNumEpilogueTmemStages * 2 + 1); + + #pragma unroll + for (uint32_t k = 0; k < kTopK; ++k) { + smem_partial[epilogue_warp_idx * kTopK + k] = partial_topk[k]; + } + __syncwarp(); + // Cross-warp sync (use named barrier) + cutlass::arch::NamedBarrier(kNumEpilogueThreads, 0).sync(); + + // Warp 0 does final top-k merge and writes to GMEM + if (epilogue_warp_idx == 0) { + TopKEntry final_topk[kTopK]; + #pragma unroll + for (uint32_t i = 0; i < kTopK; ++i) { + final_topk[i] = {-FLT_MAX, -1}; + } + + // Merge all partial top-k + #pragma unroll + for (uint32_t w = 0; w < kNumEpilogueWarps; ++w) { + #pragma unroll + for (uint32_t k = 0; k < kTopK; ++k) { + TopKEntry entry = smem_partial[w * kTopK + k]; + if (entry.score > final_topk[kTopK - 1].score) { + final_topk[kTopK - 1] = entry; + for (int j = kTopK - 2; j >= 0 && final_topk[j + 1].score > final_topk[j].score; --j) { + TopKEntry tmp = final_topk[j]; + final_topk[j] = final_topk[j + 1]; + final_topk[j + 1] = tmp; + } + } + } + } + + // Compute weights: renormalized activation values + float act_sum = 0.0f; + #pragma unroll + for (uint32_t k = 0; k < kTopK; ++k) { + // Reconstruct activation from the final expert index + // Need to load logit from TMEM again? No — we have the final_topk indices + // but we need the activation values for those indices. + // The activation is sqrt(softplus(logit)). + // We need logits for the top-k indices. + // + // Problem: we only kept scores (act + bias), not raw logits. + // Solution: store acts alongside partial_topk, or recompute from scores - bias. + // Since act = score - bias[e], and we have the bias table: + float act_k = final_topk[k].score - e_bias[final_topk[k].idx]; + act_sum += act_k; + } + + // Write output + const uint32_t out_offset = (m_block * BLOCK_M + m) * kTopK; + #pragma unroll + for (uint32_t k = 0; k < kTopK; ++k) { + float act_k = final_topk[k].score - e_bias[final_topk[k].idx]; + out_ids[out_offset + k] = final_topk[k].idx + n_block * BLOCK_N; + out_weights[out_offset + k] = (act_k / act_sum) * routed_scaling_factor; + } + } + // Sync before next m row + cutlass::arch::NamedBarrier(kNumEpilogueThreads, 0).sync(); + } + ++current_iter; + } + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumEpilogueThreads, 0).sync(); + if (warp_idx == 3) + Allocator().free(0, kNumTmemCols); + } + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) { + printf("nvfp4_fused_router_kernel requires sm_100a\n"); + } +#endif +} + +} // namespace dsv4::router