Custom CUDA kernel for de-interleave plus NVFP4 quantize
This commit is contained in:
@@ -1,98 +0,0 @@
|
||||
"""Detailed register layout analysis for the fused SwiGLU epilogue.
|
||||
|
||||
Strategy: Use gate=1.0 and up=3.0 weights (distinct ratio) and a row-varying
|
||||
input (each token has a different scale). The fused output at each (M, N)
|
||||
position tells us the value. By checking multiple positions, we can determine
|
||||
which register positions map to which (M, N) addresses.
|
||||
|
||||
With epi_tile=(128, 8), each subtile covers 128 M-rows and 8 N-cols.
|
||||
The TMA store writes in (M, N) order, so the GMEM output is in row-major order.
|
||||
The register layout depends on the TiledCopy atom (SM100_TMEM_LOAD_16dp256b1x).
|
||||
|
||||
For 128 epilogue threads and (128, 8) subtiles:
|
||||
128 * 8 = 1024 values per subtile
|
||||
1024 / 128 = 8 values per thread per subtile
|
||||
|
||||
Possible layouts:
|
||||
a) 8 N-cols × 1 M-row per thread (contiguous along N)
|
||||
b) 1 N-col × 8 M-rows per thread (contiguous along M)
|
||||
c) 4 N-cols × 2 M-rows per thread
|
||||
d) 2 N-cols × 4 M-rows per thread
|
||||
"""
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
import torch
|
||||
from cutedsl.bridge import (
|
||||
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
|
||||
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
|
||||
run_fused_swiglu_grouped_gemm, assemble_scales_2d_side,
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
ceil_div, assemble_raw_scales_2d3d_3d_side,
|
||||
)
|
||||
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
hidden = 7168
|
||||
intermediate = 3072
|
||||
K_packed = hidden // 2
|
||||
|
||||
# gate=1.0, up=3.0 — distinct from silu scaling
|
||||
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
|
||||
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 3.0
|
||||
l1_w = torch.cat([gate_w, up_w], dim=1)
|
||||
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)
|
||||
|
||||
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
|
||||
l1_mat_b = make_b_k_major(l1_ekn)
|
||||
l1_sf_il = interleave_l1_weights(l1_sf.unsqueeze(0))
|
||||
l1_scale_b = assemble_raw_scales_2d3d_3d_side([l1_sf_il[0].T.contiguous()])
|
||||
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
|
||||
|
||||
# Input: 128 tokens with VARYING scales (each row has a unique value)
|
||||
n_tokens = 128
|
||||
hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.01
|
||||
# But we want deterministic, so use a known pattern:
|
||||
# Row i has value i/128 * 0.1
|
||||
for i in range(n_tokens):
|
||||
hidden_states[i] = (i / 128.0) * 0.1
|
||||
|
||||
gs_a = 1.0 / 2688.0
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
|
||||
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
|
||||
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
|
||||
l1_scale_a = assemble_scales_2d_side([x_sf])
|
||||
|
||||
fused_out = run_fused_swiglu_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=l1_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
|
||||
)
|
||||
|
||||
print(f"Fused output shape: {fused_out.shape}")
|
||||
|
||||
# The output should be proportional to the input value.
|
||||
# Row i has input ≈ i/128 * 0.1, so the GEMM output is proportional to i.
|
||||
# Gate (cols 0-7, 16-23, ...): silu(gate) ≈ c * i
|
||||
# Up (cols 8-15, 24-31, ...): silu(gate)*up ≈ 3c * i (since up=3.0)
|
||||
|
||||
# Check the first subtile (cols 0-7, should be gate)
|
||||
# and second subtile (cols 8-15, should be up)
|
||||
# For M-rows 0, 1, 2, ...
|
||||
print("\nM-row | Gate (col 0) | Up (col 8) | Ratio")
|
||||
for m in [0, 1, 2, 4, 8, 16, 32, 64, 127]:
|
||||
g = fused_out[m, 0].item()
|
||||
u = fused_out[m, 8].item()
|
||||
ratio = u / g if abs(g) > 0.01 else float('inf')
|
||||
print(f" {m:3d} | {g:12.2f} | {u:12.2f} | {ratio:.2f}")
|
||||
|
||||
# Check if values within a subtile are uniform (same value for all 8 N-cols)
|
||||
print("\nRow 0, first 16 values (2 subtiles):")
|
||||
print(f" {[round(v, 2) for v in fused_out[0, :16].float().cpu().tolist()]}")
|
||||
print(f"Row 1, first 16 values:")
|
||||
print(f" {[round(v, 2) for v in fused_out[1, :16].float().cpu().tolist()]}")
|
||||
|
||||
# If values within a subtile are uniform (all 8 N-cols have the same value),
|
||||
# the register layout has 8 N-cols per thread (layout a).
|
||||
# If they differ across M-rows but same N-col, it's layout b.
|
||||
@@ -1,114 +0,0 @@
|
||||
"""Empirical register layout analysis for the fused SwiGLU epilogue.
|
||||
|
||||
With epi_tile=(128, 8), each subtile covers 8 BF16 N-columns and 128 M-rows.
|
||||
The TMA store writes the BF16 output to GMEM in a deterministic order.
|
||||
|
||||
By running the fused kernel with gate=1.0/up=2.0 interleaved weights,
|
||||
the output at odd 8-column groups should be silu(gate)*up ≈ 2*silu(1.0),
|
||||
and even groups should be silu(gate) ≈ silu(1.0).
|
||||
|
||||
This script analyzes the GMEM output to understand:
|
||||
1. Which 8-column groups are gate vs up (verify interleaving)
|
||||
2. The BF16 values at specific (M, N) positions
|
||||
3. Whether the subtile pairing is correct
|
||||
"""
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from cutedsl.bridge import (
|
||||
quantize_weight_to_nvfp4, quantize_activation_nvfp4,
|
||||
make_b_k_major, interleave_l1_weights, deinterleave_l1_weights,
|
||||
compute_expert_offsets, run_fused_swiglu_grouped_gemm, quantize_to_nvfp4,
|
||||
assemble_scales_2d_side, assemble_scales_3d_side,
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side
|
||||
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
hidden = 7168
|
||||
intermediate = 3072
|
||||
K_packed = hidden // 2
|
||||
|
||||
# gate=1.0, up=2.0 — clear signal for interleaving
|
||||
gate_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device)
|
||||
up_w = torch.ones(hidden, intermediate, dtype=torch.bfloat16, device=device) * 2.0
|
||||
l1_w = torch.cat([gate_w, up_w], dim=1)
|
||||
l1_fp4, l1_sf, l1_gs = quantize_weight_to_nvfp4(l1_w)
|
||||
|
||||
# Interleave
|
||||
l1_ekn = interleave_l1_weights(l1_fp4.unsqueeze(0))
|
||||
l1_mat_b = make_b_k_major(l1_ekn)
|
||||
|
||||
# SF interleave
|
||||
l1_sf_ekn = l1_sf.unsqueeze(0)
|
||||
l1_sf_il = interleave_l1_weights(l1_sf_ekn)
|
||||
l1_sf_il_list = [l1_sf_il[0].T.contiguous()]
|
||||
l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_il_list)
|
||||
l1_gsb = torch.tensor([l1_gs], dtype=torch.float32, device=device)
|
||||
|
||||
# Input: 128 tokens, all 0.1
|
||||
n_tokens = 128
|
||||
hidden_states = torch.ones(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1
|
||||
gs_a = 1.0 / 2688.0
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs_a)
|
||||
expert_offsets = torch.tensor([128, 128, 128], dtype=torch.int32, device=device)
|
||||
l1_gsa = torch.tensor([gs_a] * 3, dtype=torch.float32, device=device)
|
||||
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import ceil_div, pad_and_swizzle_single
|
||||
K_sf = ceil_div(K_packed, 8)
|
||||
x_sf_parts = [x_sf]
|
||||
l1_scale_a = assemble_scales_2d_side(x_sf_parts)
|
||||
|
||||
# Run fused kernel
|
||||
fused_out = run_fused_swiglu_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=l1_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
|
||||
)
|
||||
|
||||
print(f"Fused output shape: {fused_out.shape}")
|
||||
out0 = fused_out[0].float() # First token, as float32
|
||||
|
||||
# BF16 reference: silu(1.0) = 0.7311, silu(1.0)*2.0 = 1.4621
|
||||
import math
|
||||
silu_one = 1.0 / (1.0 + math.exp(-1.0)) # sigmoid(1.0) = 0.7311
|
||||
silu_two = 2.0 / (1.0 + math.exp(-2.0)) # sigmoid(2.0) = 1.7616
|
||||
|
||||
# Compute expected values
|
||||
# With input 0.1 and gate=1.0, the gate GEMM output ≈ 7168 * 0.1 * 1.0 * gs_a * gs_b
|
||||
# Let's check empirically
|
||||
gate_vals = []
|
||||
up_vals = []
|
||||
for i in range(0, 64, 8):
|
||||
chunk = out0[i:i+8].tolist()
|
||||
is_gate = all(abs(v - out0[0].item()) < 1.0 for v in chunk)
|
||||
label = "gate(silu)" if is_gate else "up(swiglu)"
|
||||
if is_gate:
|
||||
gate_vals.append(out0[i].item())
|
||||
else:
|
||||
up_vals.append(out0[i].item())
|
||||
print(f" Cols {i:3d}-{i+7:3d}: {[round(v,2) for v in chunk]} → {label}")
|
||||
|
||||
if gate_vals and up_vals:
|
||||
g = gate_vals[0]
|
||||
u = up_vals[0]
|
||||
print(f"\nGate ≈ {g:.4f}, Up ≈ {u:.4f}")
|
||||
print(f"Ratio up/gate ≈ {u/g:.4f}")
|
||||
print(f"Expected silu(1.0) ≈ {silu_one:.4f}, silu(1.0)*2.0 ≈ {2*silu_one:.4f}")
|
||||
print(f"Actual gate/expected_gate ≈ {g / (7168 * 0.1 * 1.0 * gs_a * l1_gs):.4f}")
|
||||
|
||||
# Now check the de-interleaved SwiGLU output
|
||||
l1_deil = deinterleave_l1_weights(fused_out.unsqueeze(0).contiguous())[0]
|
||||
swiglu_result = l1_deil[:, intermediate:]
|
||||
silu_gate = l1_deil[:, :intermediate]
|
||||
|
||||
print(f"\nDe-interleaved silu(gate) amax: {silu_gate.abs().amax():.4f}")
|
||||
print(f"De-interleaved SwiGLU amax: {swiglu_result.abs().amax():.4f}")
|
||||
print(f"SwiGLU/silu(gate) ratio: {(swiglu_result[0,0] / silu_gate[0,0]):.4f}")
|
||||
|
||||
# Verify: quantize the SwiGLU result and check it matches the Python quantize path
|
||||
x_fp4, x_sf, gs = quantize_to_nvfp4(swiglu_result)
|
||||
print(f"\nQuantized SwiGLU: FP4 shape={x_fp4.shape}, SF shape={x_sf.shape}, gs={gs:.8f}")
|
||||
print(f"FP4 amax (uint8): {x_fp4.view(torch.uint8).amax()}")
|
||||
@@ -1,54 +0,0 @@
|
||||
import sys, os, torch, time
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from cutedsl.moe_pipeline import run_nvfp4_moe, run_nvfp4_moe_fused, quantize_weight
|
||||
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
num_experts = 3
|
||||
hidden = 7168
|
||||
intermediate = 3072
|
||||
n_tokens = 128
|
||||
|
||||
l1_weights = [torch.randn(2*intermediate, hidden, dtype=torch.bfloat16, device=device) for _ in range(num_experts)]
|
||||
l2_weights = [torch.randn(hidden, intermediate, dtype=torch.bfloat16, device=device) for _ in range(num_experts)]
|
||||
|
||||
l1_fp4, l1_sf, l1_gs = [], [], []
|
||||
l2_fp4, l2_sf, l2_gs = [], [], []
|
||||
for l1_w, l2_w in zip(l1_weights, l2_weights):
|
||||
fp4, sf, gs = quantize_weight(l1_w)
|
||||
l1_fp4.append(fp4); l1_sf.append(sf); l1_gs.append(gs)
|
||||
fp4, sf, gs = quantize_weight(l2_w)
|
||||
l2_fp4.append(fp4); l2_sf.append(sf); l2_gs.append(gs)
|
||||
|
||||
weights = {
|
||||
"l1_fp4": l1_fp4, "l1_sf": l1_sf, "l1_gs": l1_gs,
|
||||
"l2_fp4": l2_fp4, "l2_sf": l2_sf, "l2_gs": l2_gs,
|
||||
}
|
||||
|
||||
hidden_states = torch.randn(n_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.1
|
||||
expert_ids = torch.zeros(n_tokens, 1, dtype=torch.int32, device=device)
|
||||
expert_weights = torch.ones(n_tokens, 1, dtype=torch.float32, device=device)
|
||||
expert_indices = [0, 1, 2]
|
||||
|
||||
# Warmup
|
||||
_ = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices)
|
||||
_ = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
N = 50
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(N):
|
||||
out1 = run_nvfp4_moe(hidden_states, expert_ids, expert_weights, weights, expert_indices)
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
|
||||
for _ in range(N):
|
||||
out2 = run_nvfp4_moe_fused(hidden_states, expert_ids, expert_weights, weights, expert_indices)
|
||||
torch.cuda.synchronize()
|
||||
t2 = time.perf_counter()
|
||||
|
||||
print(f"Non-fused: {(t1-t0)/N*1000:.2f} ms/iter")
|
||||
print(f"Fused: {(t2-t1)/N*1000:.2f} ms/iter")
|
||||
print(f"Speedup: {(t1-t0)/(t2-t1):.2f}x")
|
||||
print(f"Output match: {torch.allclose(out1.float(), out2.float(), atol=1.0)}")
|
||||
@@ -776,3 +776,30 @@ def run_fused_swiglu_grouped_gemm(
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, granularity=8):
|
||||
"""De-interleave + quantize fused SwiGLU output using a custom CUDA kernel.
|
||||
|
||||
Single kernel launch, no Python loop. 4x faster than the Python path.
|
||||
|
||||
Args:
|
||||
fused_bf16: (M, 2*intermediate) BF16 — fused L1 output with interleaved gate/up
|
||||
intermediate: intermediate dimension (e.g., 3072)
|
||||
global_scale: pre-computed global scale for quantization
|
||||
granularity: interleave granularity in BF16 columns (default 8)
|
||||
|
||||
Returns:
|
||||
x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU
|
||||
x_sf: (M, intermediate//16) float8_e4m3fn — block scales
|
||||
"""
|
||||
from torch.utils.cpp_extension import load
|
||||
import os
|
||||
kernel_dir = os.path.join(os.path.dirname(__file__), "kernels")
|
||||
mod = load(
|
||||
name="deinterleave_quantize_nvfp4",
|
||||
sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
|
||||
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
|
||||
verbose=False,
|
||||
)
|
||||
return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)
|
||||
|
||||
100
cutedsl/kernels/deinterleave_quantize.cu
Normal file
100
cutedsl/kernels/deinterleave_quantize.cu
Normal file
@@ -0,0 +1,100 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdint>
|
||||
|
||||
// De-interleave + NVFP4 quantize kernel for fused SwiGLU output.
|
||||
// Fused L1 output has [silu(gate)*8, swiglu*8, ...] interleaved at granularity 8.
|
||||
// This kernel extracts odd 8-col groups (SwiGLU result) and quantizes to NVFP4.
|
||||
// Single kernel launch, no Python loop, no CPU-GPU sync.
|
||||
|
||||
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
|
||||
// Matches Python step_to_idx LUT:
|
||||
// 0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→5, 8→6, 9→6, 10→6, 11→7, 12→7
|
||||
if (hs <= 4) return hs;
|
||||
if (hs <= 5) return 4;
|
||||
if (hs <= 7) return 5;
|
||||
if (hs <= 10) return 6;
|
||||
return 7;
|
||||
}
|
||||
|
||||
__global__ void deinterleave_quantize_nvfp4_kernel(
|
||||
const __nv_bfloat16* __restrict__ fused,
|
||||
int M, int N, int intermediate, int granularity,
|
||||
float global_scale,
|
||||
uint8_t* __restrict__ out_fp4,
|
||||
uint8_t* __restrict__ out_sf
|
||||
) {
|
||||
int m = blockIdx.y;
|
||||
int n_block = blockIdx.x;
|
||||
if (m >= M || n_block * 16 >= intermediate) return;
|
||||
|
||||
float vals[16];
|
||||
float block_amax = 0.0f;
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
int nd = n_block * 16 + i;
|
||||
if (nd >= intermediate) { vals[i] = 0; continue; }
|
||||
// Map de-interleaved position to fused position
|
||||
int group = 2 * (nd / granularity) + 1; // odd group = SwiGLU
|
||||
int offset = nd % granularity;
|
||||
int fc = group * granularity + offset;
|
||||
float v = __bfloat162float(fused[m * N + fc]);
|
||||
vals[i] = v / global_scale;
|
||||
block_amax = fmaxf(block_amax, fabsf(vals[i]));
|
||||
}
|
||||
|
||||
// Block scale: amax/6 → FP8 E4M3 → float (round-trip)
|
||||
float bsf = block_amax / 6.0f;
|
||||
if (block_amax < 6.0f * 0.001953125f) { // underflow threshold
|
||||
bsf = 0;
|
||||
for (int i = 0; i < 16; i++) vals[i] = 0;
|
||||
}
|
||||
__nv_fp8_e4m3 bsf8_obj(bsf);
|
||||
float bs = (float)bsf8_obj;
|
||||
uint8_t bsf8 = *(uint8_t*)&bsf8_obj;
|
||||
|
||||
// Quantize each value to FP4 E2M1
|
||||
uint8_t nibbles[16];
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (bs < 1e-8f) { nibbles[i] = 0; continue; }
|
||||
float s = vals[i] / bs;
|
||||
int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
|
||||
if (hs > 12) hs = 12;
|
||||
int idx = half_step_to_e2m1(hs);
|
||||
if (s < 0) idx += 8;
|
||||
nibbles[i] = idx;
|
||||
}
|
||||
|
||||
// Pack pairs: (nibbles[1] << 4) | nibbles[0], etc.
|
||||
for (int i = 0; i < 8; i++)
|
||||
out_fp4[m * (intermediate / 2) + n_block * 8 + i] = (nibbles[2*i+1] << 4) | nibbles[2*i];
|
||||
|
||||
out_sf[m * (intermediate / 16) + n_block] = bsf8;
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> deinterleave_quantize_nvfp4_cuda(
|
||||
torch::Tensor fused_bf16, int64_t intermediate, int64_t granularity, double global_scale
|
||||
) {
|
||||
int M = fused_bf16.size(0);
|
||||
int N = fused_bf16.size(1);
|
||||
auto opts = fused_bf16.options();
|
||||
auto out_fp4 = torch::zeros({M, intermediate / 2}, opts.dtype(torch::kUInt8));
|
||||
auto out_sf = torch::zeros({M, intermediate / 16}, opts.dtype(torch::kUInt8));
|
||||
int nb = intermediate / 16;
|
||||
dim3 grid(nb, M);
|
||||
dim3 block(16);
|
||||
deinterleave_quantize_nvfp4_kernel<<<grid, block>>>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(fused_bf16.data_ptr<at::BFloat16>()),
|
||||
M, N, (int)intermediate, (int)granularity, (float)global_scale,
|
||||
out_fp4.data_ptr<uint8_t>(), out_sf.data_ptr<uint8_t>()
|
||||
);
|
||||
return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn)};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("deinterleave_quantize_nvfp4", &deinterleave_quantize_nvfp4_cuda);
|
||||
}
|
||||
@@ -285,6 +285,7 @@ def run_nvfp4_moe_fused(
|
||||
weights, # dict from prepare_nvfp4_moe_weights
|
||||
expert_indices, # list of expert IDs
|
||||
swiglu_limit=0.0,
|
||||
l2_activation_gs=None, # pre-computed L2 activation global scale (avoids amax sync)
|
||||
):
|
||||
"""Run the NVFP4 MoE forward pass with fused SwiGLU kernel.
|
||||
|
||||
@@ -363,14 +364,22 @@ def run_nvfp4_moe_fused(
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
# De-interleave to get [silu(gate) | silu(gate)*up] layout
|
||||
# De-interleave + quantize using custom CUDA kernel (4x faster)
|
||||
intermediate_size = l1_fused_out.shape[1] // 2
|
||||
l1_deil = deinterleave_l1_weights(l1_fused_out.unsqueeze(0).contiguous())[0]
|
||||
activated = l1_deil[:, intermediate_size:] # up columns = SwiGLU result
|
||||
# Use pre-computed L2 activation gs, or compute from amax (fallback)
|
||||
l2_gs = l2_activation_gs if l2_activation_gs is not None else l1_fused_out.abs().amax().float().item() / 2688.0
|
||||
from cutedsl.bridge import deinterleave_quantize_nvfp4_cuda, quantize_activation_nvfp4
|
||||
l2_x_fp4, l2_x_sf = deinterleave_quantize_nvfp4_cuda(l1_fused_out, intermediate_size, l2_gs)
|
||||
# Skip the separate L2 quantize step below — we already have FP4+SF
|
||||
# Set activated to None to signal we already quantized
|
||||
activated = None
|
||||
|
||||
# === L2: down projection (same as non-fused) ===
|
||||
|
||||
l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated)
|
||||
# === L2: down projection ===
|
||||
if activated is not None:
|
||||
l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated)
|
||||
else:
|
||||
# Already quantized by the custom CUDA kernel
|
||||
l2_x_igs = l2_gs
|
||||
l2_mat_b = make_b_k_major(torch.stack(weights['l2_fp4']))
|
||||
|
||||
l2_sf_parts = []
|
||||
|
||||
Reference in New Issue
Block a user