WIP: NVFP4 fused router kernel in raw CUDA C++ using DeepGEMM primitives
- nvfp4_fused_router_kernel.cuh: 1-CTA NVFP4 GEMM + sqrt(softplus) + top-k epilogue
  - Uses DeepGEMM SM100 primitives: SM100_MMA_MXF4_SS, UTCCP, UMMA descriptors
  - 4 warp roles: TMA load, UTCCP transpose, MMA issue, epilogue
- nvfp4_fused_router_cuda.py: Python wrapper (TMA descriptor setup TBD)

NOT YET COMPILING - needs:
1. SMEM layout fix (single extern __shared__)
2. TMA descriptor creation (cuTensorMapEncodeTiled)
3. Top-k cross-warp merge completion
4. FP4 tensor format alignment with DeepGEMM
This commit is contained in:
98
dsv4/kernels/router/nvfp4_fused_router_cuda.py
Normal file
98
dsv4/kernels/router/nvfp4_fused_router_cuda.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""NVFP4 Fused Router Kernel — Python wrapper for raw CUDA C++ kernel.
|
||||
|
||||
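Compiles nvfp4_fused_router_kernel.cuh with nvcc and loads the resulting
shared library via ctypes. The PyTorch-facing entry point is currently a
stub: it raises NotImplementedError until TMA descriptor setup is implemented.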
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import ctypes
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
KERNEL_DIR = Path(__file__).parent
|
||||
CUH_PATH = KERNEL_DIR / "nvfp4_fused_router_kernel.cuh"
|
||||
DGEMM_INCLUDE = KERNEL_DIR.parent.parent.parent / "third_party" / "DeepGEMM" / "deep_gemm" / "include"
|
||||
|
||||
# Cache compiled .so
|
||||
_compiled_so = None
|
||||
|
||||
def _get_nvcc_path():
|
||||
for p in ["/usr/local/cuda/bin/nvcc", "/opt/cuda/bin/nvcc"]:
|
||||
if os.path.exists(p):
|
||||
return p
|
||||
return "nvcc"
|
||||
|
||||
def _compile_kernel():
    """Build (if needed) and load the fused-router shared library.

    The loaded ``ctypes.CDLL`` handle is memoised in the module-level
    ``_compiled_so``. An existing ``.so`` next to this file is reused when it is
    newer than the ``.cuh`` source (or when the source file is absent).

    NOTE(review): the ``.cuh`` visible in this commit only defines a function
    template, so the resulting library exports no callable symbol until an
    explicit instantiation / ``extern "C"`` launcher is added.

    Returns:
        ctypes.CDLL: handle to ``_nvfp4_fused_router_kernel.so``.

    Raises:
        RuntimeError: nvcc exited with a non-zero status; the message carries
            nvcc's stderr.
        subprocess.TimeoutExpired: the build exceeded 300 seconds.
    """
    global _compiled_so
    if _compiled_so is not None:
        return _compiled_so

    so_path = KERNEL_DIR / "_nvfp4_fused_router_kernel.so"

    # Reuse a cached build when it is newer than the kernel source.
    if so_path.exists():
        cuh_mtime = CUH_PATH.stat().st_mtime if CUH_PATH.exists() else 0
        if so_path.stat().st_mtime > cuh_mtime:
            _compiled_so = ctypes.CDLL(str(so_path))
            return _compiled_so

    nvcc = _get_nvcc_path()
    # Build into a private temporary file and rename on success, so that an
    # interrupted or failed build can never leave a truncated library at
    # `so_path` that the mtime check above would later accept as up to date.
    tmp_so_path = so_path.with_name(f".{so_path.stem}.{os.getpid()}.tmp.so")

    cmd = [
        nvcc,
        "-shared", "-o", str(tmp_so_path),
        "-arch=sm_100a",
        "-O3",
        "--use_fast_math",
        f"-I{DGEMM_INCLUDE}",
        "-I" + str(KERNEL_DIR),
        "-std=c++20",
        "-Xcompiler", "-fPIC",
        "--expt-relaxed-constexpr",
        # nvcc chooses the language from the file suffix and `.cuh` is not a
        # recognised CUDA source suffix, so force CUDA compilation explicitly.
        "-x", "cu",
        str(CUH_PATH),
    ]
    print(f"Compiling NVFP4 fused router kernel: {' '.join(cmd)}", flush=True)
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        if result.returncode != 0:
            print(f"nvcc STDERR:\n{result.stderr}", flush=True)
            # Carry the diagnostics in the exception too, so callers that
            # capture or log exceptions do not lose the compiler output.
            raise RuntimeError(
                f"nvcc compilation failed: {result.returncode}\n{result.stderr}"
            )
        os.replace(tmp_so_path, so_path)
    finally:
        # Only present here if the build failed or the rename did not happen.
        if tmp_so_path.exists():
            tmp_so_path.unlink()

    _compiled_so = ctypes.CDLL(str(so_path))
    return _compiled_so
|
||||
|
||||
|
||||
def run_nvfp4_fused_router_cuda(
    x_fp4: torch.Tensor,      # [M, K_packed] FP4 activation
    x_sf: torch.Tensor,       # [M, K_sf] UE8M0 activation scale factors
    w_fp4: torch.Tensor,      # [N, K_packed] FP4 weight (K-major after make_b_k_major)
    w_sf: torch.Tensor,       # [N, K_sf] UE8M0 weight scale factors
    e_bias: torch.Tensor,     # [N] float32 bias
    routed_scaling_factor: float,
    top_k: int,
    num_experts: int = 384,
    hidden_dim: int = 7168,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Entry point for the NVFP4 fused router kernel (not yet implemented).

    Intended contract once the kernel is wired up:

    Args:
        x_fp4: packed FP4 activations, ``[M, K_packed]``.
        x_sf: UE8M0 activation scale factors, ``[M, K_sf]``.
        w_fp4: packed FP4 gate weights, ``[N, K_packed]``.
        w_sf: UE8M0 weight scale factors, ``[N, K_sf]``.
        e_bias: per-expert float32 bias, ``[N]``.
        routed_scaling_factor: multiplier applied to the renormalised weights.
        top_k: number of experts selected per token.
        num_experts: ``N``.
        hidden_dim: ``K``.

    Returns:
        weights: ``[M, top_k]`` float32
        ids:     ``[M, top_k]`` int32

    Raises:
        NotImplementedError: always, for now. Launching the raw CUDA kernel
            needs TMA descriptors, which this wrapper does not build yet.
    """
    # Fail fast: the arguments are deliberately not inspected and no output
    # buffers are allocated, since nothing can be launched yet. Callers should
    # catch this and use the 2-kernel path.
    raise NotImplementedError(
        "Raw CUDA fused router kernel requires TMA descriptor setup. "
        "Use the 2-kernel path (Nvfp4Linear + activation_topk) for now."
    )
|
||||
644
dsv4/kernels/router/nvfp4_fused_router_kernel.cuh
Normal file
644
dsv4/kernels/router/nvfp4_fused_router_kernel.cuh
Normal file
@@ -0,0 +1,644 @@
|
||||
#pragma once
|
||||
/**
|
||||
* DSV4 NVFP4 Fused Router Kernel — Raw CUDA C++ for SM100 Blackwell
|
||||
*
|
||||
* Single-kernel fusion: NVFP4 block-scaled GEMM (X @ W_gate) + sqrt(softplus) + top-k epilogue
|
||||
*
|
||||
* Warp layout (1 CTA, 256 threads):
|
||||
* Warp 0 (TMA): Load X (FP4) + SFA + W (FP4) + SFB from GMEM -> SMEM
|
||||
* Warp 1 (UTCCP): Transpose SFA/SFB SMEM -> TMEM via UTCCP
|
||||
* Warp 2 (MMA): Issue UMMA mxf4.block_scale instructions, accumulate in TMEM
|
||||
* Warps 3-7 (EPI): TMEM -> regs, sqrt(softplus) + e_bias, top-k heap, renorm, GMEM store
|
||||
*
|
||||
* Math (DSV4 §2.1):
|
||||
* logit = X @ W_gate (NVFP4 block-scaled GEMM, FP32 accumulator in TMEM)
|
||||
* act = sqrt(softplus(logit)) softplus(x) = max(x,0) + log(1+exp(-|x|))
|
||||
* score = act + e_bias[e]
|
||||
* ids = argtopk(score, k)
|
||||
* w = (act[ids] / sum(act[ids])) * scaling
|
||||
*/
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
|
||||
// DeepGEMM primitives
|
||||
#include <deep_gemm/common/math.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/tma_copy.cuh>
|
||||
#include <deep_gemm/mma/sm100.cuh>
|
||||
#include <deep_gemm/ptx/tcgen05.cuh>
|
||||
#include <deep_gemm/ptx/ld_st.cuh>
|
||||
#include <deep_gemm/ptx/utils.cuh>
|
||||
|
||||
namespace dsv4::router {
|
||||
|
||||
// ============================================================================
|
||||
// Softplus helper (matching PyTorch behavior)
|
||||
// ============================================================================
|
||||
/// Overflow-safe softplus, log(1 + e^x), evaluated as
/// max(x, 0) + log1p(e^{-|x|}) so the exponential never sees a positive argument.
__device__ __forceinline__
float softplus(const float x) {
    const float magnitude = fabsf(x);
    const float positive_part = fmaxf(x, 0.0f);
    const float tail = log1pf(expf(-magnitude));
    return positive_part + tail;
}
|
||||
|
||||
// ============================================================================
|
||||
// Top-k heap element
|
||||
// ============================================================================
|
||||
/// One candidate in a top-k selection. Plain aggregate; callers
/// brace-initialise it as {score, idx}, so the member order is part of the contract.
struct TopKEntry {
    float score;  ///< ranking key used for comparisons
    int   idx;    ///< index of the candidate; -1 is used by callers as "empty slot"
};
|
||||
|
||||
// ============================================================================
|
||||
// NVFP4 Fused Router Kernel
|
||||
// ============================================================================
|
||||
template <
|
||||
uint32_t kNumExperts, // N (384 for DSV4)
|
||||
uint32_t kHiddenDim, // K (7168 for DSV4)
|
||||
uint32_t kTopK, // top-k (6 for DSV4)
|
||||
uint32_t BLOCK_M, // tile M (1 for decode, up to 128 for prefill)
|
||||
uint32_t BLOCK_N, // tile N (must be 128, 256, or 384)
|
||||
uint32_t BLOCK_K, // tile K (must be 128)
|
||||
uint32_t kNumStages, // pipeline stages
|
||||
uint32_t kNumSMs, // number of SMs to use
|
||||
// Derived constants
|
||||
uint32_t SF_BLOCK_M = ((BLOCK_M + 127) / 128) * 128,
|
||||
uint32_t SF_BLOCK_N = ((BLOCK_N + 127) / 128) * 128,
|
||||
uint32_t kNumTMAThreads = 32,
|
||||
uint32_t kNumUTCCPThreads = 32,
|
||||
uint32_t kNumMMAThreads = 32,
|
||||
uint32_t kNumEpilogueThreads = 160, // 5 warps for epilogue
|
||||
uint32_t kNumThreads = kNumTMAThreads + kNumUTCCPThreads + kNumMMAThreads + kNumEpilogueThreads,
|
||||
uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32
|
||||
>
|
||||
CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1)
|
||||
void nvfp4_fused_router_kernel(
|
||||
const uint32_t shape_m,
|
||||
const uint32_t shape_n,
|
||||
const uint32_t shape_k,
|
||||
float* __restrict__ out_weights, // [M, kTopK] FP32 weights
|
||||
int32_t* __restrict__ out_ids, // [M, kTopK] int32 expert IDs
|
||||
const float* __restrict__ e_bias, // [kNumExperts] bias
|
||||
const float routed_scaling_factor, // scaling factor for weights
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a, // activation (FP4, K-major)
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_a, // activation SF (UE8M0)
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b, // weight (FP4, K-major)
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_b // weight SF (UE8M0)
|
||||
) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLANK_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
using Allocator = cute::TMEM::Allocator1Sm;
|
||||
|
||||
// ================================================================
|
||||
// Static assertions
|
||||
// ================================================================
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K must be 128 for NVFP4 GEMM");
|
||||
DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Epilogue threads must be multiple of 128");
|
||||
DG_STATIC_ASSERT(BLOCK_N % 16 == 0, "BLOCK_N must be multiple of 16");
|
||||
DG_STATIC_ASSERT(kTopK <= 32, "Top-k must fit in a single warp");
|
||||
|
||||
// ================================================================
|
||||
// UMMA configs
|
||||
// ================================================================
|
||||
constexpr uint32_t UMMA_M = 128; // 1-CTA group
|
||||
constexpr uint32_t UMMA_N = BLOCK_N;
|
||||
constexpr uint32_t UMMA_K = 32; // MMA instruction K
|
||||
constexpr uint32_t kGranK = 32; // Scale factor granularity
|
||||
|
||||
// Scale factor alignment
|
||||
constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
||||
|
||||
// ================================================================
|
||||
// Shared memory layout
|
||||
// ================================================================
|
||||
// Activation: FP4, K-major. SMEM: [LOAD_BLOCK_M, BLOCK_K] in FP4 (2 bits/elem, packed)
|
||||
// Weight: FP4, K-major. SMEM: [BLOCK_N, BLOCK_K] in FP4
|
||||
using a_dtype_t = cutlass::float_e2m1_t; // FP4 activation (packed as x2)
|
||||
using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
|
||||
|
||||
constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // = 64 for BLOCK_K=128
|
||||
constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t);
|
||||
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
|
||||
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
|
||||
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
|
||||
|
||||
// Epilogue output: [BLOCK_M, kTopK] for both weights and IDs
|
||||
constexpr uint32_t SMEM_OUT_SIZE = BLOCK_M * kTopK * (sizeof(float) + sizeof(int32_t));
|
||||
|
||||
// Tensor memory
|
||||
constexpr uint32_t kNumEpilogueTmemStages = 2;
|
||||
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueTmemStages;
|
||||
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
|
||||
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
|
||||
constexpr uint32_t kNumTmemCols = deep_gemm::utils::get_num_aligned_tmem_cols<
|
||||
kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
|
||||
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
|
||||
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Total SMEM
|
||||
constexpr uint32_t kSharedMemoryAlignment = 1024;
|
||||
extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[];
|
||||
|
||||
// A/B shared memory
|
||||
auto smem_a = deep_gemm::utils::PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<a_dtype_t*>(smem_buffer + i * SMEM_A_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_b = deep_gemm::utils::PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<b_dtype_t*>(smem_buffer + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
});
|
||||
|
||||
// SFA/SFB shared memory
|
||||
auto sf_start_ptr = smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
auto smem_sfa = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) {
|
||||
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_sfb = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) {
|
||||
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
|
||||
});
|
||||
|
||||
// Barriers and TMEM pointer
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_sfb[kNumStages]);
|
||||
auto full_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + i; });
|
||||
auto empty_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
|
||||
auto with_sf_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages * 2 + i; });
|
||||
auto tmem_full_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages * 3 + i; });
|
||||
auto tmem_empty_barriers = deep_gemm::utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages * 3 + kNumEpilogueTmemStages + i; });
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueTmemStages * 2);
|
||||
|
||||
// ================================================================
|
||||
// Thread indices
|
||||
// ================================================================
|
||||
const uint32_t sm_idx = blockIdx.x;
|
||||
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
const auto lane_idx = deep_gemm::ptx::get_lane_idx();
|
||||
|
||||
// ================================================================
|
||||
// Prefetch TMA descriptors
|
||||
// ================================================================
|
||||
if (warp_idx == 0) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sf_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sf_b);
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Initialize barriers
|
||||
// ================================================================
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(1);
|
||||
with_sf_barriers[i]->init(32); // UTCCP warp arrives
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumEpilogueTmemStages; ++i) {
|
||||
tmem_full_barriers[i]->init(1);
|
||||
tmem_empty_barriers[i]->init(kNumEpilogueThreads);
|
||||
}
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (warp_idx == 2) {
|
||||
// Allocate tensor memory
|
||||
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Wait for primary kernel completion
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
// ================================================================
|
||||
// Scheduler: assign M blocks to SMs
|
||||
// ================================================================
|
||||
const uint32_t num_m_blocks = (shape_m + BLOCK_M - 1) / BLOCK_M;
|
||||
const uint32_t num_k_blocks = (shape_k + BLOCK_K - 1) / BLOCK_K;
|
||||
const uint32_t num_n_blocks = (shape_n + BLOCK_N - 1) / BLOCK_N;
|
||||
|
||||
// Pipeline
|
||||
uint32_t stage_idx = 0, phase = 0;
|
||||
auto advance_pipeline = [&]() {
|
||||
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
|
||||
phase ^= stage_idx == 0;
|
||||
};
|
||||
|
||||
// ================================================================
|
||||
// Dispatch warps
|
||||
// ================================================================
|
||||
if (warp_idx == 0) {
|
||||
// ============================================================
|
||||
// TMA load warp
|
||||
// ============================================================
|
||||
cutlass::arch::warpgroup_reg_dealloc<56>();
|
||||
|
||||
for (uint32_t m_block = sm_idx; m_block < num_m_blocks; m_block += kNumSMs) {
|
||||
for (uint32_t n_block = 0; n_block < num_n_blocks; ++n_block) {
|
||||
for (uint32_t k_block = 0; k_block < num_k_blocks; ++k_block, advance_pipeline()) {
|
||||
// Wait consumer release
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
|
||||
// Compute offsets
|
||||
uint32_t m_idx = m_block * BLOCK_M;
|
||||
uint32_t n_idx = n_block * BLOCK_N;
|
||||
uint32_t k_idx = k_block * BLOCK_K;
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
// TMA load activation (A)
|
||||
deep_gemm::tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, a_dtype_t>(
|
||||
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
|
||||
// TMA load weight (B)
|
||||
deep_gemm::tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, b_dtype_t>(
|
||||
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx);
|
||||
|
||||
uint32_t num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
|
||||
// SFA: loaded every k stage (granularity = 32 = 1 block)
|
||||
{
|
||||
uint32_t sfa_m_idx = m_block * SF_BLOCK_M;
|
||||
uint32_t sfa_k_idx = k_block;
|
||||
deep_gemm::tma::copy<SF_BLOCK_M, 1, 0>(
|
||||
&tensor_map_sf_a, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx);
|
||||
num_arrival_bytes += SMEM_SFA_SIZE_PER_STAGE;
|
||||
}
|
||||
|
||||
// SFB: loaded every k stage
|
||||
{
|
||||
uint32_t sfb_n_idx = n_block * SF_BLOCK_N;
|
||||
uint32_t sfb_k_idx = k_block;
|
||||
deep_gemm::tma::copy<SF_BLOCK_N, 1, 0>(
|
||||
&tensor_map_sf_b, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx);
|
||||
num_arrival_bytes += SMEM_SFB_SIZE_PER_STAGE;
|
||||
}
|
||||
|
||||
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_idx == 1) {
|
||||
// ============================================================
|
||||
// UTCCP transposer warp
|
||||
// ============================================================
|
||||
cutlass::arch::warpgroup_reg_dealloc<56>();
|
||||
|
||||
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
|
||||
uint32_t values[4];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; ++i)
|
||||
values[i] = deep_gemm::ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
|
||||
__syncwarp();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; ++i)
|
||||
deep_gemm::ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
|
||||
};
|
||||
|
||||
for (uint32_t m_block = sm_idx; m_block < num_m_blocks; m_block += kNumSMs) {
|
||||
for (uint32_t n_block = 0; n_block < num_n_blocks; ++n_block) {
|
||||
for (uint32_t k_block = 0; k_block < num_k_blocks; ++k_block, advance_pipeline()) {
|
||||
// Wait TMA arrival
|
||||
full_barriers[stage_idx]->wait(phase);
|
||||
|
||||
// Transpose SFA for UTCCP
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++i) {
|
||||
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
|
||||
}
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
|
||||
// Transpose SFB for UTCCP
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++i) {
|
||||
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
|
||||
}
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
|
||||
// Arrive at with_sf barrier (signals MMA warp that SF is ready)
|
||||
with_sf_barriers[stage_idx]->arrive(0u);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_idx == 2) {
|
||||
// ============================================================
|
||||
// MMA issue warp
|
||||
// ============================================================
|
||||
cutlass::arch::warpgroup_reg_dealloc<56>();
|
||||
|
||||
// Make instruction descriptor for NVFP4 block-scaled MMA
|
||||
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<
|
||||
a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
|
||||
UMMA_M, UMMA_N,
|
||||
cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
||||
auto sf_desc = deep_gemm::mma::sm100::make_sf_desc(nullptr);
|
||||
|
||||
// Pre-compute UMMA descriptors for each pipeline stage
|
||||
auto a_desc = deep_gemm::mma::sm100::make_umma_desc<
|
||||
cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
||||
auto b_desc = deep_gemm::mma::sm100::make_umma_desc<
|
||||
cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
||||
uint32_t a_desc_lo = lane_idx < kNumStages ?
|
||||
a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
||||
uint32_t b_desc_lo = lane_idx < kNumStages ?
|
||||
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
||||
|
||||
uint32_t current_iter = 0;
|
||||
|
||||
for (uint32_t m_block = sm_idx; m_block < num_m_blocks; m_block += kNumSMs) {
|
||||
for (uint32_t n_block = 0; n_block < num_n_blocks; ++n_block) {
|
||||
// Wait TMEM release
|
||||
auto accum_stage = current_iter % kNumEpilogueTmemStages;
|
||||
auto accum_phase = (current_iter / kNumEpilogueTmemStages) & 1;
|
||||
tmem_empty_barriers[accum_stage]->wait(accum_phase ^ 1);
|
||||
deep_gemm::ptx::tcgen05_after_thread_sync();
|
||||
|
||||
bool first_k = true;
|
||||
for (uint32_t k_block = 0; k_block < num_k_blocks; ++k_block, advance_pipeline()) {
|
||||
// Wait TMA + SF ready
|
||||
with_sf_barriers[stage_idx]->wait(phase);
|
||||
deep_gemm::ptx::tcgen05_after_thread_sync();
|
||||
|
||||
const auto a_desc_base_lo = deep_gemm::ptx::exchange(a_desc_lo, stage_idx);
|
||||
const auto b_desc_base_lo = deep_gemm::ptx::exchange(b_desc_lo, stage_idx);
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
// UTCCP: copy SFA to TMEM
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++i) {
|
||||
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
deep_gemm::mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
||||
}
|
||||
// UTCCP: copy SFB to TMEM
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++i) {
|
||||
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
deep_gemm::mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
||||
}
|
||||
|
||||
// Issue UMMA instructions
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++k) {
|
||||
auto runtime_desc = deep_gemm::mma::sm100::make_runtime_instr_desc_with_sf_id(
|
||||
instr_desc, k, k);
|
||||
a_desc.lo = deep_gemm::mma::sm100::advance_umma_desc_lo<
|
||||
cute::UMMA::Major::K, BLOCK_M, kSwizzleAMode, a_dtype_t>(
|
||||
a_desc_base_lo, 0, k * UMMA_K);
|
||||
b_desc.lo = deep_gemm::mma::sm100::advance_umma_desc_lo<
|
||||
cute::UMMA::Major::K, BLOCK_N, kSwizzleBMode, b_dtype_t>(
|
||||
b_desc_base_lo, 0, k * UMMA_K);
|
||||
|
||||
deep_gemm::ptx::SM100_MMA_MXF4_SS::fma(
|
||||
a_desc, b_desc, accum_stage * UMMA_N,
|
||||
first_k ? 0 : 1, runtime_desc,
|
||||
kTmemStartColOfSFA, kTmemStartColOfSFB);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Commit + arrive at empty barrier
|
||||
{
|
||||
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
||||
if (k_block == num_k_blocks - 1) {
|
||||
// Signal TMEM full for epilogue
|
||||
asm volatile("tcgen05.commit.cta-group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
|
||||
:: "r"(cute::cast_smem_ptr_to_uint(tmem_full_barriers[accum_stage])));
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
first_k = false;
|
||||
}
|
||||
++current_iter;
|
||||
}
|
||||
}
|
||||
} else if (warp_idx >= 3) {
|
||||
// ============================================================
|
||||
// Epilogue warps: TMEM -> regs -> sqrt(softplus) + top-k -> GMEM
|
||||
// ============================================================
|
||||
cutlass::arch::warpgroup_reg_alloc<224>();
|
||||
|
||||
const auto epilogue_warp_idx = warp_idx - 3;
|
||||
const auto epilogue_wg_idx = epilogue_warp_idx / 4;
|
||||
const auto warp_idx_in_wg = epilogue_warp_idx % 4;
|
||||
|
||||
// TMEM load helper
|
||||
auto tmem_load_32 = [](const uint32_t tmem_addr, float* accum, const uint32_t count) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < count; i += 2) {
|
||||
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_addr + i,
|
||||
reinterpret_cast<uint32_t*>(accum)[i],
|
||||
reinterpret_cast<uint32_t*>(accum)[i + 1]);
|
||||
}
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
};
|
||||
|
||||
uint32_t current_iter = 0;
|
||||
|
||||
for (uint32_t m_block = sm_idx; m_block < num_m_blocks; m_block += kNumSMs) {
|
||||
for (uint32_t n_block = 0; n_block < num_n_blocks; ++n_block) {
|
||||
// Wait TMEM full
|
||||
auto accum_stage = current_iter % kNumEpilogueTmemStages;
|
||||
auto accum_phase = (current_iter / kNumEpilogueTmemStages) & 1;
|
||||
tmem_full_barriers[accum_stage]->wait(accum_phase);
|
||||
deep_gemm::ptx::tcgen05_after_thread_sync();
|
||||
|
||||
const uint32_t tmem_base = accum_stage * UMMA_N;
|
||||
|
||||
// Each epilogue warp group processes a chunk of the N dimension
|
||||
// For BLOCK_N <= 256, a single warp group (4 warps) handles all N
|
||||
// Load all N values from TMEM into registers
|
||||
// TMEM layout: row-major [BLOCK_M, UMMA_N] where UMMA_N = BLOCK_N
|
||||
// Each TMEM row has BLOCK_N float values
|
||||
// With 32 lanes and TMEM_LOAD_32dp32b32x, each lane loads 1 value per load
|
||||
|
||||
const uint32_t valid_m = deep_gemm::math::min(BLOCK_M, shape_m - m_block * BLOCK_M);
|
||||
const uint32_t valid_n = deep_gemm::math::min(BLOCK_N, shape_n - n_block * BLOCK_N);
|
||||
|
||||
// Process each row in the M tile
|
||||
for (uint32_t m = 0; m < valid_m; ++m) {
|
||||
// Each warp loads a portion of N from TMEM
|
||||
// TMEM address: tmem_base + m * UMMA_N
|
||||
float logits[8]; // each warp holds up to 8 N values
|
||||
|
||||
const uint32_t n_start = epilogue_warp_idx * 8;
|
||||
const uint32_t n_count = deep_gemm::math::min(8u, valid_n - deep_gemm::math::min(n_start, valid_n));
|
||||
|
||||
// Load from TMEM
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 8; ++i) {
|
||||
logits[i] = 0.0f;
|
||||
const uint32_t n = n_start + i;
|
||||
if (n < valid_n) {
|
||||
uint32_t tmem_addr = tmem_base + m * UMMA_N + n;
|
||||
uint32_t val;
|
||||
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_addr, val);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
logits[i] = *reinterpret_cast<float*>(&val);
|
||||
}
|
||||
}
|
||||
|
||||
// Signal TMEM consumed on last row
|
||||
if (m == valid_m - 1) {
|
||||
deep_gemm::ptx::tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage]->arrive(0u);
|
||||
}
|
||||
|
||||
// ========================================================
|
||||
// Epilogue: sqrt(softplus) + bias + top-k
|
||||
// ========================================================
|
||||
float acts[8];
|
||||
float scores[8];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 8; ++i) {
|
||||
acts[i] = sqrtf(softplus(logits[i]));
|
||||
const uint32_t n = n_start + i;
|
||||
scores[i] = (n < valid_n) ? acts[i] + e_bias[n] : -FLT_MAX;
|
||||
}
|
||||
|
||||
// Partial top-k: each warp finds its top-k, then reduce across warps
|
||||
// For decode (M=1, N=384), we have 5 warps each holding ~77 values
|
||||
// Simple partial sort + warp-level top-k merge
|
||||
TopKEntry partial_topk[kTopK];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kTopK; ++i) {
|
||||
partial_topk[i] = {-FLT_MAX, -1};
|
||||
}
|
||||
|
||||
// Insert each score into partial heap
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 8; ++i) {
|
||||
const uint32_t n = n_start + i;
|
||||
if (n < valid_n && scores[i] > partial_topk[kTopK - 1].score) {
|
||||
partial_topk[kTopK - 1] = {scores[i], static_cast<int>(n)};
|
||||
// Bubble up
|
||||
for (int j = kTopK - 2; j >= 0 && partial_topk[j + 1].score > partial_topk[j].score; --j) {
|
||||
TopKEntry tmp = partial_topk[j];
|
||||
partial_topk[j] = partial_topk[j + 1];
|
||||
partial_topk[j + 1] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cross-warp top-k merge using warp shuffle
|
||||
// Each warp broadcasts its kTopK entries, others merge
|
||||
// Final result stored by warp 0
|
||||
TopKEntry global_topk[kTopK];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kTopK; ++i) {
|
||||
global_topk[i] = {-FLT_MAX, -1};
|
||||
}
|
||||
|
||||
// Sequential: each warp broadcasts its partial results
|
||||
#pragma unroll
|
||||
for (uint32_t src_warp = 0; src_warp < kNumEpilogueWarps; ++src_warp) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kTopK; ++k) {
|
||||
// Broadcast score and idx from src_warp
|
||||
float remote_score = partial_topk[k].score;
|
||||
int remote_idx = partial_topk[k].idx;
|
||||
|
||||
// Shuffle across warps (use __shfl_sync for intra-warp,
|
||||
// then shared memory for inter-warp)
|
||||
// For simplicity, use shared memory to exchange
|
||||
// TODO: optimize with warp shuffle + ballot
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
|
||||
// For the initial version, let's use a simpler approach:
|
||||
// Each warp writes partial top-k to SMEM, then warp 0 does final merge
|
||||
// (This is the same pattern as DeepGEMM's combine reduction)
|
||||
|
||||
// Store partial top-k to SMEM for cross-warp merge
|
||||
extern __shared__ __align__(1024) uint8_t smem_topk_buffer[];
|
||||
// Offset after the GEMM SMEM
|
||||
// Actually, by this point GEMM is done and we can reuse all SMEM
|
||||
// For simplicity, overlay on the barrier area (barriers are done)
|
||||
auto smem_partial = reinterpret_cast<TopKEntry*>(
|
||||
barrier_start_ptr + kNumStages * 3 + kNumEpilogueTmemStages * 2 + 1);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kTopK; ++k) {
|
||||
smem_partial[epilogue_warp_idx * kTopK + k] = partial_topk[k];
|
||||
}
|
||||
__syncwarp();
|
||||
// Cross-warp sync (use named barrier)
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads, 0).sync();
|
||||
|
||||
// Warp 0 does final top-k merge and writes to GMEM
|
||||
if (epilogue_warp_idx == 0) {
|
||||
TopKEntry final_topk[kTopK];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kTopK; ++i) {
|
||||
final_topk[i] = {-FLT_MAX, -1};
|
||||
}
|
||||
|
||||
// Merge all partial top-k
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumEpilogueWarps; ++w) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kTopK; ++k) {
|
||||
TopKEntry entry = smem_partial[w * kTopK + k];
|
||||
if (entry.score > final_topk[kTopK - 1].score) {
|
||||
final_topk[kTopK - 1] = entry;
|
||||
for (int j = kTopK - 2; j >= 0 && final_topk[j + 1].score > final_topk[j].score; --j) {
|
||||
TopKEntry tmp = final_topk[j];
|
||||
final_topk[j] = final_topk[j + 1];
|
||||
final_topk[j + 1] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute weights: renormalized activation values
|
||||
float act_sum = 0.0f;
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kTopK; ++k) {
|
||||
// Reconstruct activation from the final expert index
|
||||
// Need to load logit from TMEM again? No — we have the final_topk indices
|
||||
// but we need the activation values for those indices.
|
||||
// The activation is sqrt(softplus(logit)).
|
||||
// We need logits for the top-k indices.
|
||||
//
|
||||
// Problem: we only kept scores (act + bias), not raw logits.
|
||||
// Solution: store acts alongside partial_topk, or recompute from scores - bias.
|
||||
// Since act = score - bias[e], and we have the bias table:
|
||||
float act_k = final_topk[k].score - e_bias[final_topk[k].idx];
|
||||
act_sum += act_k;
|
||||
}
|
||||
|
||||
// Write output
|
||||
const uint32_t out_offset = (m_block * BLOCK_M + m) * kTopK;
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kTopK; ++k) {
|
||||
float act_k = final_topk[k].score - e_bias[final_topk[k].idx];
|
||||
out_ids[out_offset + k] = final_topk[k].idx + n_block * BLOCK_N;
|
||||
out_weights[out_offset + k] = (act_k / act_sum) * routed_scaling_factor;
|
||||
}
|
||||
}
|
||||
// Sync before next m row
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads, 0).sync();
|
||||
}
|
||||
++current_iter;
|
||||
}
|
||||
}
|
||||
|
||||
// Free tensor memory
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads, 0).sync();
|
||||
if (warp_idx == 3)
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0) {
|
||||
printf("nvfp4_fused_router_kernel requires sm_100a\n");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace dsv4::router
|
||||
Reference in New Issue
Block a user