Custom CUDA kernel for de-interleave plus NVFP4 quantize

This commit is contained in:
2026-05-20 04:39:47 +00:00
parent 7fa81e6990
commit fffb2144ae
6 changed files with 142 additions and 272 deletions

View File

@@ -1,98 +0,0 @@
"""Detailed register layout analysis for the fused SwiGLU epilogue.
Strategy: Use gate=1.0 and up=3.0 weights (distinct ratio) and a row-varying
input (each token has a different scale). The fused output at each (M, N)
position tells us the value. By checking multiple positions, we can determine
which register positions map to which (M, N) addresses.
With epi_tile=(128, 8), each subtile covers 128 M-rows and 8 N-cols.
The TMA store writes in (M, N) order, so the GMEM output is in row-major order.
The register layout depends on the TiledCopy atom (SM100_TMEM_LOAD_16dp256b1x).
For 128 epilogue threads and (128, 8) subtiles:
128 * 8 = 1024 values per subtile
1024 / 128 = 8 values per thread per subtile
Possible layouts:
a) 8 N-cols × 1 M-row per thread (contiguous along N)
b) 1 N-col × 8 M-rows per thread (contiguous along M)
c) 4 N-cols × 2 M-rows per thread
d) 2 N-cols × 4 M-rows per thread
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
from cutedsl.bridge import (
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
run_fused_swiglu_grouped_gemm, assemble_scales_2d_side,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ceil_div, assemble_raw_scales_2d3d_3d_side,
)
torch.manual_seed(42)
device = "cuda"
hidden = 7168
intermediate = 3072
K_packed = hidden // 2
# gate=1.0, up=3.0 — distinct from silu scaling
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 3.0
l1_w = torch.cat([gate_w, up_w], dim=1)
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
l1_mat_b = make_b_k_major(l1_ekn)
l1_sf_il = interleave_l1_weights(l1_sf.unsqueeze(0))
l1_scale_b = assemble_raw_scales_2d3d_3d_side([l1_sf_il[0].T.contiguous()])
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
# Input: 128 tokens with VARYING scales (each row has a unique value)
n_tokens = 128
hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.01
# But we want deterministic, so use a known pattern:
# Row i has value i/128 * 0.1
for i in range(n_tokens):
hidden_states[i] = (i / 128.0) * 0.1
gs_a = 1.0 / 2688.0
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
l1_scale_a = assemble_scales_2d_side([x_sf])
fused_out = run_fused_swiglu_grouped_gemm(
mat_a=x_fp4, mat_b=l1_mat_b,
scale_a=l1_scale_a, scale_b=l1_scale_b,
expert_offsets=expert_offsets,
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
)
print(f"Fused output shape: {fused_out.shape}")
# The output should be proportional to the input value.
# Row i has input ≈ i/128 * 0.1, so the GEMM output is proportional to i.
# Gate (cols 0-7, 16-23, ...): silu(gate) ≈ c * i
# Up (cols 8-15, 24-31, ...): silu(gate)*up ≈ 3c * i (since up=3.0)
# Check the first subtile (cols 0-7, should be gate)
# and second subtile (cols 8-15, should be up)
# For M-rows 0, 1, 2, ...
print("\nM-row | Gate (col 0) | Up (col 8) | Ratio")
for m in [0, 1, 2, 4, 8, 16, 32, 64, 127]:
g = fused_out[m, 0].item()
u = fused_out[m, 8].item()
ratio = u / g if abs(g) > 0.01 else float('inf')
print(f" {m:3d} | {g:12.2f} | {u:12.2f} | {ratio:.2f}")
# Check if values within a subtile are uniform (same value for all 8 N-cols)
print("\nRow 0, first 16 values (2 subtiles):")
print(f" {[round(v, 2) for v in fused_out[0, :16].float().cpu().tolist()]}")
print(f"Row 1, first 16 values:")
print(f" {[round(v, 2) for v in fused_out[1, :16].float().cpu().tolist()]}")
# If values within a subtile are uniform (all 8 N-cols have the same value),
# the register layout has 8 N-cols per thread (layout a).
# If they differ across M-rows but same N-col, it's layout b.

View File

@@ -1,114 +0,0 @@
"""Empirical register layout analysis for the fused SwiGLU epilogue.
With epi_tile=(128, 8), each subtile covers 8 BF16 N-columns and 128 M-rows.
The TMA store writes the BF16 output to GMEM in a deterministic order.
By running the fused kernel with gate=1.0/up=2.0 interleaved weights,
the output at odd 8-column groups should be silu(gate)*up ≈ 2*silu(1.0),
and even groups should be silu(gate) ≈ silu(1.0).
This script analyzes the GMEM output to understand:
1. Which 8-column groups are gate vs up (verify interleaving)
2. The BF16 values at specific (M, N) positions
3. Whether the subtile pairing is correct
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
import torch.nn.functional as F
from cutedsl.bridge import (
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
compute_expert_offsets, run_fused_swiglu_grouped_gemm, quantize_to_nvfp4,
assemble_scales_2d_side, assemble_scales_3d_side,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side
torch.manual_seed(42)
device = "cuda"
hidden = 7168
intermediate = 3072
K_packed = hidden // 2
# gate=1.0, up=2.0 — clear signal for interleaving
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 2.0
l1_w = torch.cat([gate_w, up_w], dim=1)
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)
# Interleave
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
l1_mat_b = make_b_k_major(l1_ekn)
# SF interleave
l1_sf_ekn = l1_sf.unsqueeze(0)
l1_sf_il = interleave_l1_weights(l1_sf_ekn)
l1_sf_il_list = [l1_sf_il[0].T.contiguous()]
l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_il_list)
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
# Input: 128 tokens, all 0.1
n_tokens = 128
hidden_states = torch.ones(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1
gs_a = 1.0 / 2688.0
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import ceil_div, pad_and_swizzle_single
K_sf = ceil_div(K_packed, 8)
x_sf_parts = [x_sf]
l1_scale_a = assemble_scales_2d_side(x_sf_parts)
# Run fused kernel
fused_out = run_fused_swiglu_grouped_gemm(
mat_a=x_fp4, mat_b=l1_mat_b,
scale_a=l1_scale_a, scale_b=l1_scale_b,
expert_offsets=expert_offsets,
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
)
print(f"Fused output shape: {fused_out.shape}")
out0 = fused_out[0].float() # First token, as float32
# BF16 reference: silu(1.0) = 0.7311, silu(1.0)*2.0 = 1.4621
import math
silu_one = 1.0 / (1.0 + math.exp(-1.0)) # sigmoid(1.0) = 0.7311
silu_two = 2.0 / (1.0 + math.exp(-2.0)) # sigmoid(2.0) = 1.7616
# Compute expected values
# With input 0.1 and gate=1.0, the gate GEMM output ≈ 7168 * 0.1 * 1.0 * gs_a * gs_b
# Let's check empirically
gate_vals = []
up_vals = []
for i in range(0, 64, 8):
chunk = out0[i:i+8].tolist()
is_gate = all(abs(v - out0[0].item()) < 1.0 for v in chunk)
label = "gate(silu)" if is_gate else "up(swiglu)"
if is_gate:
gate_vals.append(out0[i].item())
else:
up_vals.append(out0[i].item())
print(f" Cols {i:3d}-{i+7:3d}: {[round(v,2) for v in chunk]}{label}")
if gate_vals and up_vals:
g = gate_vals[0]
u = up_vals[0]
print(f"\nGate ≈ {g:.4f}, Up ≈ {u:.4f}")
print(f"Ratio up/gate ≈ {u/g:.4f}")
print(f"Expected silu(1.0) ≈ {silu_one:.4f}, silu(1.0)*2.0 ≈ {2*silu_one:.4f}")
print(f"Actual gate/expected_gate ≈ {g / (7168 * 0.1 * 1.0 * gs_a * l1_gs):.4f}")
# Now check the de-interleaved SwiGLU output
l1_deil = deinterleave_l1_weights(fused_out.unsqueeze(0).contiguous())[0]
swiglu_result = l1_deil[:, intermediate:]
silu_gate = l1_deil[:, :intermediate]
print(f"\nDe-interleaved silu(gate) amax: {silu_gate.abs().amax():.4f}")
print(f"De-interleaved SwiGLU amax: {swiglu_result.abs().amax():.4f}")
print(f"SwiGLU/silu(gate) ratio: {(swiglu_result[0,0] / silu_gate[0,0]):.4f}")
# Verify: quantize the SwiGLU result and check it matches the Python quantize path
x_fp4, x_sf, gs = quantize_to_nvfp4(swiglu_result)
print(f"\nQuantized SwiGLU: FP4 shape={x_fp4.shape}, SF shape={x_sf.shape}, gs={gs:.8f}")
print(f"FP4 amax (uint8): {x_fp4.view(torch.uint8).amax()}")

View File

@@ -1,54 +0,0 @@
import sys, os, torch, time
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from cutedsl.moe_pipeline import run_nvfp4_moe, run_nvfp4_moe_fused, quantize_weight
torch.manual_seed(42)
device = "cuda"
num_experts = 3
hidden = 7168
intermediate = 3072
n_tokens = 128
l1_weights = [torch.randn(2*intermediate, hidden, dtype=torch.bfloat16, device=device) for _ in range(num_experts)]
l2_weights = [torch.randn(hidden, intermediate, dtype=torch.bfloat16, device=device) for _ in range(num_experts)]
l1_fp4, l1_sf, l1_gs = [], [], []
l2_fp4, l2_sf, l2_gs = [], [], []
for l1_w, l2_w in zip(l1_weights, l2_weights):
fp4, sf, gs = quantize_weight(l1_w)
l1_fp4.append(fp4); l1_sf.append(sf); l1_gs.append(gs)
fp4, sf, gs = quantize_weight(l2_w)
l2_fp4.append(fp4); l2_sf.append(sf); l2_gs.append(gs)
weights = {
"l1_fp4": l1_fp4, "l1_sf": l1_sf, "l1_gs": l1_gs,
"l2_fp4": l2_fp4, "l2_sf": l2_sf, "l2_gs": l2_gs,
}
hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1
expert_ids = torch.zeros(n_tokens, 1, dtype=torch.int32, device=device)
expert_weights = torch.ones(n_tokens, 1, dtype=torch.float32, device=device)
expert_indices = [0, 1, 2]
# Warmup
_ = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices)
_ = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices)
torch.cuda.synchronize()
# Benchmark
N = 50
t0 = time.perf_counter()
for _ in range(N):
out1 = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices)
torch.cuda.synchronize()
t1 = time.perf_counter()
for _ in range(N):
out2 = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices)
torch.cuda.synchronize()
t2 = time.perf_counter()
print(f"Non-fused: {(t1-t0)/N*1000:.2f} ms/iter")
print(f"Fused: {(t2-t1)/N*1000:.2f} ms/iter")
print(f"Speedup: {(t1-t0)/(t2-t1):.2f}x")
print(f"Output match: {torch.allclose(out1.float(), out2.float(), atol=1.0)}")

View File

@@ -776,3 +776,30 @@ def run_fused_swiglu_grouped_gemm(
return out
def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, granularity=8):
"""De-interleave + quantize fused SwiGLU output using a custom CUDA kernel.
Single kernel launch, no Python loop. 4x faster than the Python path.
Args:
fused_bf16: (M, 2*intermediate) BF16 — fused L1 output with interleaved gate/up
intermediate: intermediate dimension (e.g., 3072)
global_scale: pre-computed global scale for quantization
granularity: interleave granularity in BF16 columns (default 8)
Returns:
x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU
x_sf: (M, intermediate//16) float8_e4m3fn — block scales
"""
from torch.utils.cpp_extension import load
import os
kernel_dir = os.path.join(os.path.dirname(__file__), "kernels")
mod = load(
name="deinterleave_quantize_nvfp4",
sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
verbose=False,
)
return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)

View File

@@ -0,0 +1,100 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cstdint>
// De-interleave + NVFP4 quantize kernel for fused SwiGLU output.
// Fused L1 output has [silu(gate)*8, swiglu*8, ...] interleaved at granularity 8.
// This kernel extracts odd 8-col groups (SwiGLU result) and quantizes to NVFP4.
// Single kernel launch, no Python loop, no CPU-GPU sync.
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
// Matches Python step_to_idx LUT:
// 0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→5, 8→6, 9→6, 10→6, 11→7, 12→7
if (hs <= 4) return hs;
if (hs <= 5) return 4;
if (hs <= 7) return 5;
if (hs <= 10) return 6;
return 7;
}
__global__ void deinterleave_quantize_nvfp4_kernel(
const __nv_bfloat16* __restrict__ fused,
int M, int N, int intermediate, int granularity,
float global_scale,
uint8_t* __restrict__ out_fp4,
uint8_t* __restrict__ out_sf
) {
int m = blockIdx.y;
int n_block = blockIdx.x;
if (m >= M || n_block * 16 >= intermediate) return;
float vals[16];
float block_amax = 0.0f;
for (int i = 0; i < 16; i++) {
int nd = n_block * 16 + i;
if (nd >= intermediate) { vals[i] = 0; continue; }
// Map de-interleaved position to fused position
int group = 2 * (nd / granularity) + 1; // odd group = SwiGLU
int offset = nd % granularity;
int fc = group * granularity + offset;
float v = __bfloat162float(fused[m * N + fc]);
vals[i] = v / global_scale;
block_amax = fmaxf(block_amax, fabsf(vals[i]));
}
// Block scale: amax/6 → FP8 E4M3 → float (round-trip)
float bsf = block_amax / 6.0f;
if (block_amax < 6.0f * 0.001953125f) { // underflow threshold
bsf = 0;
for (int i = 0; i < 16; i++) vals[i] = 0;
}
__nv_fp8_e4m3 bsf8_obj(bsf);
float bs = (float)bsf8_obj;
uint8_t bsf8 = *(uint8_t*)&bsf8_obj;
// Quantize each value to FP4 E2M1
uint8_t nibbles[16];
for (int i = 0; i < 16; i++) {
if (bs < 1e-8f) { nibbles[i] = 0; continue; }
float s = vals[i] / bs;
int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
if (hs > 12) hs = 12;
int idx = half_step_to_e2m1(hs);
if (s < 0) idx += 8;
nibbles[i] = idx;
}
// Pack pairs: (nibbles[1] << 4) | nibbles[0], etc.
for (int i = 0; i < 8; i++)
out_fp4[m * (intermediate / 2) + n_block * 8 + i] = (nibbles[2*i+1] << 4) | nibbles[2*i];
out_sf[m * (intermediate / 16) + n_block] = bsf8;
}
std::tuple<torch::Tensor, torch::Tensor> deinterleave_quantize_nvfp4_cuda(
torch::Tensor fused_bf16, int64_t intermediate, int64_t granularity, double global_scale
) {
int M = fused_bf16.size(0);
int N = fused_bf16.size(1);
auto opts = fused_bf16.options();
auto out_fp4 = torch::zeros({M, intermediate / 2}, opts.dtype(torch::kUInt8));
auto out_sf = torch::zeros({M, intermediate / 16}, opts.dtype(torch::kUInt8));
int nb = intermediate / 16;
dim3 grid(nb, M);
dim3 block(16);
deinterleave_quantize_nvfp4_kernel<<<grid, block>>>(
reinterpret_cast<const __nv_bfloat16*>(fused_bf16.data_ptr<at::BFloat16>()),
M, N, (int)intermediate, (int)granularity, (float)global_scale,
out_fp4.data_ptr<uint8_t>(), out_sf.data_ptr<uint8_t>()
);
return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn)};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deinterleave_quantize_nvfp4", &deinterleave_quantize_nvfp4_cuda);
}

View File

@@ -285,6 +285,7 @@ def run_nvfp4_moe_fused(
weights, # dict from prepare_nvfp4_moe_weights
expert_indices, # list of expert IDs
swiglu_limit=0.0,
l2_activation_gs=None, # pre-computed L2 activation global scale (avoids amax sync)
):
"""Run the NVFP4 MoE forward pass with fused SwiGLU kernel.
@@ -363,14 +364,22 @@ def run_nvfp4_moe_fused(
swiglu_limit=swiglu_limit,
)
# De-interleave to get [silu(gate) | silu(gate)*up] layout
# De-interleave + quantize using custom CUDA kernel (4x faster)
intermediate_size = l1_fused_out.shape[1] // 2
l1_deil = deinterleave_l1_weights(l1_fused_out.unsqueeze(0).contiguous())[0]
activated = l1_deil[:, intermediate_size:] # up columns = SwiGLU result
# Use pre-computed L2 activation gs, or compute from amax (fallback)
l2_gs = l2_activation_gs if l2_activation_gs is not None else l1_fused_out.abs().amax().float().item() / 2688.0
from cutedsl.bridge import deinterleave_quantize_nvfp4_cuda, quantize_activation_nvfp4
l2_x_fp4, l2_x_sf = deinterleave_quantize_nvfp4_cuda(l1_fused_out, intermediate_size, l2_gs)
# Skip the separate L2 quantize step below — we already have FP4+SF
# Set activated to None to signal we already quantized
activated = None
# === L2: down projection (same as non-fused) ===
l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated)
# === L2: down projection ===
if activated is not None:
l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated)
else:
# Already quantized by the custom CUDA kernel
l2_x_igs = l2_gs
l2_mat_b = make_b_k_major(torch.stack(weights['l2_fp4']))
l2_sf_parts = []