Quantize BF16 gate weight to NVFP4 for fused router + add global scales to GEMM
CRITICAL: Checkpoint stores gate weights as BF16, not NVFP4. Previous code fell back to BF16 cuBLAS because weight_scale was missing. Now we quantize the BF16 gate weight to NVFP4 at load time using quantize_to_nvfp4() and pass the result to the fused router kernel. Also added global scale (gsa, gsb) parameters to the kernel: - gsa (activation global scale) applied during activation quantization - gsb (weight global scale) applied in epilogue before sqrt(softplus) - The MMA output is (A * SFA) @ (B * SFB), missing gsa*gsb - Epilogue now computes sqrt(softplus(logit * gsa * gsb)) instead of sqrt(softplus(logit))
This commit is contained in:
@@ -262,7 +262,7 @@ class Nvfp4FusedRouterKernel:
|
||||
# run() — Python entry point
|
||||
# -----------------------------------------------------------------
|
||||
def run(self, mat_a, mat_b, scale_a, scale_b, mat_c,
|
||||
M, N, K, stream=None):
|
||||
M, N, K, gsa, gsb, stream=None):
|
||||
if stream is None:
|
||||
stream = cuda.CUstream(0)
|
||||
|
||||
@@ -336,7 +336,7 @@ class Nvfp4FusedRouterKernel:
|
||||
self.c_smem_layout_staged,
|
||||
self.epi_tile,
|
||||
tile_sched_params,
|
||||
M, N, K,
|
||||
M, N, K, gsa, gsb,
|
||||
).launch(
|
||||
grid=grid, block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
@@ -359,7 +359,7 @@ class Nvfp4FusedRouterKernel:
|
||||
c_smem_layout_staged,
|
||||
epi_tile,
|
||||
tile_sched_params,
|
||||
M, N, K):
|
||||
M, N, K, gsa, gsb):
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
@@ -723,12 +723,15 @@ class Nvfp4FusedRouterKernel:
|
||||
acc_pipeline.consumer_release(acc_cs)
|
||||
acc_cs.advance()
|
||||
|
||||
# Activation: sqrt(softplus(logit))
|
||||
# softplus(x) = max(x, 0) + log(1 + exp(-|x|))
|
||||
# This replaces SwiGLU in the MoE epilogue
|
||||
# Activation: sqrt(softplus(logit * gsa * gsb))
|
||||
# Global scales are applied before the activation, same as
|
||||
# how MoE epilogue applies them before SwiGLU.
|
||||
# The MMA output is (A * SFA) @ (B * SFB), missing gsa*gsb.
|
||||
scale = cutlass.Float32(gsa * gsb)
|
||||
acc_vec = tTR_rAcc.load()
|
||||
for e in cutlass.range(cute.size(acc_vec), unroll=4):
|
||||
logit = acc_vec[e]
|
||||
logit = acc_vec[e] * scale
|
||||
# softplus(x) = max(x, 0) + log(1 + exp(-|x|))
|
||||
abs_x = cute.math.absf(logit)
|
||||
pos = cute.math.fmax(logit, cutlass.Float32(0.0))
|
||||
exp_neg = cute.math.exp(-abs_x)
|
||||
@@ -856,6 +859,8 @@ def run_nvfp4_fused_router(
|
||||
scale_b=cute_sfb,
|
||||
mat_c=cute_c,
|
||||
M=N, N=E, K=K,
|
||||
gsa=gsa,
|
||||
gsb=gsb_val,
|
||||
)
|
||||
|
||||
# Add e_bias (selection bias) and run top-k
|
||||
|
||||
@@ -706,11 +706,28 @@ def main():
|
||||
gate_input_scale=gate_isc.to(dev) if gate_isc is not None else torch.tensor(1.0 / (6.0 * 448.0), device=dev),
|
||||
)
|
||||
else:
|
||||
# BF16 fallback
|
||||
# BF16 gate weight: quantize to NVFP4 for fused kernel
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous()
|
||||
router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32))
|
||||
gw = gw.bfloat16().to(dev)
|
||||
# Quantize BF16 → NVFP4 for fused router kernel
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
# gsb (weight global scale) = gw_gs from quantization
|
||||
# gsa (activation global scale) = 1.0 (applied during activation quantization inside kernel)
|
||||
# Actually: gsa is passed to quantize_activation_nvfp4 inside run_nvfp4_fused_router
|
||||
# We need to compute the correct gsa. For NVFP4, gsa = 1/(max_val * 448)
|
||||
# But since activation is quantized at runtime, gsa = input_scale from Nvfp4Linear = 1/(6*448)
|
||||
router.load_nvfp4_fused_gate(
|
||||
gate_weight=gw_fp4,
|
||||
gate_weight_scale=gw_sf,
|
||||
gate_ws2=torch.tensor(gw_gs, device=dev), # gsb = weight global scale
|
||||
gate_input_scale=torch.tensor(1.0 / (6.0 * 448.0), device=dev), # gsa = activation global scale
|
||||
)
|
||||
else:
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
|
||||
Reference in New Issue
Block a user