CRITICAL REWRITE of nvfp4_fused_router_kernel.py:

- REMOVED: Raw pointer SMEM merge (storage.merge_scores.data_ptr()[idx] = val). This crashed the CuTeDSL MLIR optimizer. Never use raw pointer indexing inside CuTeDSL kernels.
- REMOVED: Per-thread top-k accumulation + 128-thread SMEM merge. Too complex for MLIR, caused SIGABRT during compilation.
- ADDED: MoE-style epilogue (TMEM→regs→activation→SMEM→TMA store→GMEM) using paired copy atoms from CUTLASS (epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition). Structurally identical to the proven FusedSwiGLUScaledGroupedGemmKernel epilogue. This SHOULD compile.
- Activation: sqrt(softplus(logit)) in registers (replaces SwiGLU)
- Output: FP32 activated scores written to GMEM via TMA store
- Top-k handled by activation_topk CUDA kernel in Python wrapper

Other changes:

- _activation_topk.py: Added run_fused_activation_topk_pre_activated() for top-k + renorm on pre-activated scores (PyTorch reference, not CUDA kernel)
- dense_router_dispatch_nvfp4_fused: Updated to match new kernel API
- Kernel now uses standard _compute_stages() for SMEM budget calculation
- Kernel now uses compute_epilogue_tile_shape() for epi_tile (not hardcoded)
- C pipeline (PipelineTmaStore) added for SMEM→GMEM overlap
102 lines
4.2 KiB
Python
102 lines
4.2 KiB
Python
"""DSV4 Dense Router — NVFP4 GEMM + sqrt(softplus) + bias + top-k.
|
|
|
|
Production paths (in priority order):
|
|
1. NVFP4 fused router kernel (nvfp4_fused_router_kernel.py):
|
|
Single-kernel blockscaled GEMM + fused router epilogue.
|
|
No intermediate BF16 logits buffer (FP32 activated scores are written to GMEM for the top-k phase). Pure NVFP4 + Blackwell tensor cores.
|
|
2. NVFP4 GEMM + activation_topk (2-kernel path):
|
|
Nvfp4Linear (Blackwell tensor cores) + fused activation_topk CUDA kernel.
|
|
3. BF16 cuBLAS fallback: When NVFP4 scales are not available in the
|
|
checkpoint, dense_router_dispatch uses torch.nn.functional.linear
|
|
(cuBLAS, SM100 tensor cores) instead.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
from typing import Tuple, Optional
|
|
import torch
|
|
|
|
|
|
def dense_router_dispatch(
    hidden_states: torch.Tensor,  # [N, hidden_size] BF16
    W_gate: torch.Tensor,  # [hidden_size, num_experts] BF16
    e_bias: torch.Tensor,  # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,  # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,  # [N, top_k] int32, pre-allocated
):
    """Run the fallback dense-router path: PyTorch GEMM, then activation + top-k.

    The router logits are produced with ``torch.nn.functional.linear``. Both
    operands are upcast with ``.float()`` first, so the matmul itself is
    evaluated in FP32 even though the inputs are annotated as BF16. Because
    ``W_gate.T`` is supplied as the weight, ``linear`` computes
    ``hidden_states @ W_gate``.

    The logits are then handed to ``run_fused_activation_topk`` together with
    the bias, scaling factor, ``top_k`` and the two output buffers.

    Args:
        hidden_states: Token activations (annotated as [N, hidden_size]).
        W_gate: Gate projection matrix (annotated as [hidden_size, num_experts]).
        e_bias: Per-expert bias forwarded unchanged to the top-k helper.
        routed_scaling_factor: Scalar forwarded unchanged to the top-k helper.
        top_k: Number of experts to select per token.
        out_weights: Pre-allocated buffer passed to the helper; presumably
            filled in place -- confirm against ``run_fused_activation_topk``.
        out_ids: Pre-allocated buffer passed to the helper; presumably filled
            in place -- confirm against ``run_fused_activation_topk``.

    Returns:
        None. Any value returned by the helper is discarded.
    """
    # Upcast both GEMM operands to FP32 before the projection.
    tokens_fp32 = hidden_states.float()
    gate_fp32 = W_gate.T.float()
    logits = torch.nn.functional.linear(tokens_fp32, gate_fp32)

    # Imported lazily, as in the original, so importing this module does not
    # require the activation/top-k helper to be importable.
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk

    run_fused_activation_topk(
        logits,
        e_bias,
        routed_scaling_factor,
        top_k,
        out_weights,
        out_ids,
    )
|
|
|
|
|
|
def dense_router_dispatch_nvfp4(
    hidden_states: torch.Tensor,  # [N, hidden_size] BF16
    gate_lin,  # Nvfp4Linear instance
    e_bias: torch.Tensor,  # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,  # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,  # [N, top_k] int32, pre-allocated
):
    """Run the two-kernel dense-router path: ``gate_lin`` GEMM, then top-k.

    ``gate_lin`` is called on ``hidden_states`` to produce the router logits
    (per the original notes it is an ``Nvfp4Linear`` module; its type is not
    checked here). The result is upcast to FP32 with ``.float()`` and handed
    to ``run_fused_activation_topk`` together with the bias, scaling factor,
    ``top_k`` and the two output buffers.

    Args:
        hidden_states: Token activations (annotated as [N, hidden_size]).
        gate_lin: Callable gate projection applied to ``hidden_states``.
        e_bias: Per-expert bias forwarded unchanged to the top-k helper.
        routed_scaling_factor: Scalar forwarded unchanged to the top-k helper.
        top_k: Number of experts to select per token.
        out_weights: Pre-allocated buffer passed to the helper; presumably
            filled in place -- confirm against ``run_fused_activation_topk``.
        out_ids: Pre-allocated buffer passed to the helper; presumably filled
            in place -- confirm against ``run_fused_activation_topk``.

    Returns:
        None. Any value returned by the helper is discarded.
    """
    gemm_out = gate_lin(hidden_states)
    logits = gemm_out.float()  # (N, E) FP32

    # Imported lazily, as in the original, so importing this module does not
    # require the activation/top-k helper to be importable.
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk

    run_fused_activation_topk(
        logits,
        e_bias,
        routed_scaling_factor,
        top_k,
        out_weights,
        out_ids,
    )
|
|
|
|
|
|
def dense_router_dispatch_nvfp4_fused(
    hidden_states: torch.Tensor,  # [N, hidden_size] BF16
    gate_weight: torch.Tensor,  # [K_packed, E_packed] uint8 NVFP4 weight
    gate_weight_scale: torch.Tensor,  # [K_sf, E_sf] FP8 E4M3 weight scale
    gate_ws2: torch.Tensor,  # weight_scale_2 (scalar or per-output)
    gate_input_scale: torch.Tensor,  # input_scale (activation global scale base)
    e_bias: torch.Tensor,  # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,  # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,  # [N, top_k] int32, pre-allocated
):
    """Run the NVFP4 fused dense-router path and copy its results out.

    This wrapper reduces the two global scales to Python floats, calls
    ``run_nvfp4_fused_router``, and copies the first ``N`` rows of the
    returned weights and ids into the caller's pre-allocated buffers, where
    ``N = hidden_states.shape[0]``.

    According to the original notes (the kernel is not visible here), the
    callee works in two phases:

    Phase 1: CuTeDSL NVFP4 blockscaled GEMM + sqrt(softplus) epilogue.
        The activation is quantized to NVFP4, the GEMM runs on Blackwell
        tensor cores, and sqrt(softplus) is fused in the epilogue
        (TMEM→regs→activation→SMEM→GMEM). FP32 activated scores are written
        to GMEM; no intermediate BF16 logits are produced.

    Phase 2: top-k + renorm on the activated scores.

    Args:
        hidden_states: Token activations; only ``shape[0]`` is read here.
        gate_weight: Packed gate weight, forwarded as ``mat_b``.
        gate_weight_scale: Block scale for the weight, forwarded as ``scale_b``.
        gate_ws2: Second-level weight scale, reduced to a float (``gsb_val``).
        gate_input_scale: Input scale, reduced to a float (``gsa``).
        e_bias: Per-expert bias forwarded unchanged.
        routed_scaling_factor: Scalar forwarded unchanged.
        top_k: Number of experts to select per token.
        out_weights: Destination for the first ``N`` rows of the weights.
        out_ids: Destination for the first ``N`` rows of the expert ids.

    Returns:
        None. Results are delivered through ``out_weights`` and ``out_ids``.
    """
    from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router

    def _scale_as_float(scale: torch.Tensor) -> float:
        # A one-element tensor is read directly; anything else is collapsed
        # to its mean.
        # NOTE(review): averaging a per-output scale discards per-channel
        # information -- confirm the kernel really expects a single scalar.
        # NOTE(review): ``.item()`` reads the value back to the host, which
        # presumably synchronizes when the tensor lives on the GPU -- confirm
        # this is acceptable on the hot path.
        if scale.numel() == 1:
            return scale.float().item()
        return scale.float().mean().item()

    gsa = _scale_as_float(gate_input_scale)
    gsb_val = _scale_as_float(gate_ws2)

    result_w, result_ids = run_nvfp4_fused_router(
        hidden_states=hidden_states,
        mat_b=gate_weight,
        scale_b=gate_weight_scale,
        gsa=gsa,
        gsb_val=gsb_val,
        e_bias=e_bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
    )

    # Only the first ``num_tokens`` rows are copied, on both the source and
    # destination side, so either tensor may have more rows than tokens.
    num_tokens = hidden_states.shape[0]
    out_weights[:num_tokens].copy_(result_w[:num_tokens])
    out_ids[:num_tokens].copy_(result_ids[:num_tokens])
|