@@ -3,15 +3,17 @@
Exercises the production kernel stack end-to-end:
- NVFP4 GEMM kernels (CuTeDSL ScaledGroupedGemm) for all projections
- 6-warp TMA FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh)
- 6-warp TMA FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh) with sink bias
- CSA/HCA compressor (token-level softmax)
- Indexer score+topk (indexer_score_topk.cu)
- Indexer score+topk
- Dense/Hash router kernels
- Production mHC (Sinkhorn-Knopp, B_l transposed, [pre,post,comb])
- Production Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert
- Production Nvfp4Linear, Nvfp4MoE, Nvfp4SharedExpert
This is NOT a PyTorch reference — it calls the actual kernel stack.
Use as ground truth for vLLM / SGLang integration.
NO PyTorch SDPA fallback. NO dequant+matmul for production projections.
ALL tensor-core NVFP4 GEMMs. ALL kernel paths.
This is the ground truth for vLLM / SGLang integration.
"""
import os , sys , time , json , math , argparse , logging
import torch
@@ -93,79 +95,8 @@ def unweighted_rmsnorm(x, eps=1e-6):
return xf * xf . pow ( 2 ) . mean ( - 1 , keepdim = True ) . add ( eps ) . rsqrt ( )
# =====================================================================
# mHC (matches dsv4/layers/mhc.py)
# =====================================================================
# Small positive constant used by the mHC code as a numerical floor.
HC_EPS = 1e-6


def sinkhorn_knopp(logits, t_max=20, eps=HC_EPS):
    """Balance ``logits`` towards a (approximately) doubly-stochastic matrix.

    The last two axes of ``logits`` are treated as the matrix.  A softmax
    over the last axis makes every entry positive; the matrix is then
    normalised alternately along axis ``-2`` and axis ``-1``.

    Args:
        logits: tensor with at least two dimensions.
        t_max: number of normalisations along axis ``-2``.  The first one
            happens right after the softmax; each of the remaining
            ``t_max - 1`` iterations normalises axis ``-1`` and then
            axis ``-2`` again.
        eps: added to the softmax output and to every normalising sum so
            that no division by zero can occur.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    M = torch.softmax(logits, -1) + eps
    M = M / (M.sum(-2, keepdim=True) + eps)
    for _ in range(t_max - 1):
        M = M / (M.sum(-1, keepdim=True) + eps)
        M = M / (M.sum(-2, keepdim=True) + eps)
    return M
class mHCBlock:
    """PyTorch implementation of one mHC mixing block.

    ``pre_block`` collapses the ``n_hc`` streams of the state into a single
    input for a sub-layer and returns the mixing context; ``post_block``
    uses that context to write the sub-layer output back into the state.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, t_max=20, device='cuda:0'):
        self.d, self.n_hc, self.K = hidden_dim, n_hc, n_hc * hidden_dim
        self.t_max, self.device = t_max, device

    def load(self, fn, base, scale):
        """Split the packed checkpoint tensors into pre/post/comb parts.

        Args:
            fn: projection weights; rows ``[0:n]`` are "pre", ``[n:2n]``
                are "post" and the remaining rows are "comb".
            base: bias terms laid out in the same pre/post/comb order.
            scale: indexable with at least three scalar entries
                (pre, post, comb).
        """
        n = self.n_hc

        def to_f32(t):
            return t.to(self.device, torch.float32).contiguous()

        self.W_pre = to_f32(fn[0:n])
        self.W_post = to_f32(fn[n:2 * n])
        self.W_comb = to_f32(fn[2 * n:])
        # Stacked once here instead of on every pre_block() call.
        self.W_stacked = torch.cat([self.W_pre, self.W_post, self.W_comb])
        self.S_pre = to_f32(base[0:n].reshape(1, n))
        self.S_post = to_f32(base[n:2 * n].reshape(n, 1))
        self.S_comb = to_f32(base[2 * n:].reshape(n, n))
        self.alpha_pre = scale[0].item()
        self.alpha_post = scale[1].item()
        self.alpha_comb = scale[2].item()

    @staticmethod
    def init_state(emb, n_hc=4):
        """Replicate ``emb`` along a new axis 1 to form ``n_hc`` streams."""
        return emb.unsqueeze(1).expand(-1, n_hc, -1).clone()

    def pre_block(self, X):
        """Mix the streams of ``X`` into one sub-layer input.

        Args:
            X: rank-3 state ``(T, n, d)`` with ``n * d == self.K``.

        Returns:
            ``(x_in, ctx)`` where ``x_in`` is the BF16 mixed input and
            ``ctx`` holds the ``'B'`` and ``'C'`` tensors that
            ``post_block`` consumes.
        """
        T, n, _ = X.shape
        Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16())
        proj = Xn.float() @ self.W_stacked.T
        # RMS-normalise the projection, then round through BF16.
        rms_inv = proj.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        proj = (proj * rms_inv).bfloat16().float()

        pre_t = self.alpha_pre * proj[:, :n] + self.S_pre.flatten().unsqueeze(0)
        post_t = self.alpha_post * proj[:, n:2 * n] + self.S_post.flatten().unsqueeze(0)
        comb_t = (self.alpha_comb * proj[:, 2 * n:2 * n + n * n]
                  + self.S_comb.flatten().unsqueeze(0))

        A = torch.sigmoid(pre_t) + HC_EPS
        C = 2.0 * torch.sigmoid(post_t)
        B = sinkhorn_knopp(comb_t.reshape(T, n, n), t_max=self.t_max)

        x_in = torch.bmm(A.unsqueeze(1), X.float()).squeeze(1).bfloat16()
        return x_in, {'B': B, 'C': C}

    def post_block(self, X, F_out, ctx):
        """Combine the previous state with the sub-layer output.

        Computes ``C[..., None] * F_out[:, None] + B^T @ X`` in FP32 and
        returns the result as BF16.
        """
        BX = torch.bmm(ctx['B'].transpose(-1, -2), X.float())
        CF = ctx['C'].unsqueeze(-1) * F_out.unsqueeze(1)
        return (CF.float() + BX).bfloat16()
# =====================================================================
# HcHead
# =====================================================================
class HcHead:
    """Read a single hidden vector per token out of the multi-stream state."""

    def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'):
        self.K, self.device, self.n_hc = n_hc * hidden_dim, device, n_hc

    def load(self, fn, base, scale=None):
        """Store the projection ``fn`` and bias ``base`` as FP32 tensors.

        ``scale`` is converted to a Python float; when it is ``None`` the
        scale defaults to ``1.0``.
        """
        self.fn = fn.to(self.device, torch.float32).contiguous()
        self.base = base.to(self.device, torch.float32).contiguous()
        if scale is not None:
            self.scale = scale.to(self.device, torch.float32).item()
        else:
            self.scale = 1.0

    def forward(self, X):
        """Collapse the stream axis (axis 1) of ``X`` with sigmoid gates.

        ``X`` is flattened to ``(T, self.K)`` for the gate projection, so
        it presumably has shape ``(T, n_hc, hidden_dim)`` — confirm with
        the caller.  Returns a BF16 tensor with axis 1 summed out.
        """
        T = X.shape[0]
        Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16())
        mix = F.linear(Xn, self.fn[:self.n_hc]).float()
        gate = torch.sigmoid(mix * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS
        return (gate.unsqueeze(-1) * X.float()).sum(1).bfloat16()
# =====================================================================
# NVFP4 dequant (fallback for projections not yet using kernel GEMM)
# NVFP4 dequant — used ONLY for compressor/indexer projections
# (these don't go through the CuTeDSL GEMM kernel yet)
# =====================================================================
def dequant_nvfp4 ( weight , weight_scale , weight_scale_2 = None , input_scale = None ) :
O , I2 = weight . shape
@@ -180,7 +111,7 @@ def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
if weight_scale_2 is not None : s = s * weight_scale_2 . float ( )
return ( w * s ) . bfloat16 ( )
def nvfp4_linear_ref(x, weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Reference NVFP4 linear: dequantise the weight, then ``F.linear``.

    All quantisation tensors are forwarded unchanged to ``dequant_nvfp4``;
    the dequantised weight is applied to ``x`` without a bias.
    """
    return F.linear(x, dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale))


# The file refers to this helper under both names; keep the old one working.
nvfp4_linear = nvfp4_linear_ref
def get_nvfp4_weight ( w , pfx , proj_name ) :
@@ -188,16 +119,32 @@ def get_nvfp4_weight(w, pfx, proj_name):
return ( w . get ( f " { k } .weight " ) , w . get ( f " { k } .weight_scale " ) ,
w . get ( f " { k } .weight_scale_2 " ) , w . get ( f " { k } .input_scale " ) )
def do_nvfp4_linear_ref(x, w, pfx, proj_name):
    """Look up an NVFP4 projection in ``w`` and apply it to ``x``.

    Returns ``None`` when the projection weight is absent so that callers
    can detect a missing projection; otherwise returns the result of the
    reference (dequantise + matmul) linear on ``x``'s device.
    """
    weight, ws, ws2, isc = get_nvfp4_weight(w, pfx, proj_name)
    if weight is None:
        return None
    d = x.device
    return nvfp4_linear_ref(
        x, weight.to(d), ws.to(d),
        ws2.to(d) if ws2 is not None else None,
        isc.to(d) if isc is not None else None,
    )


# The file refers to this helper under both names; keep the old one working.
do_nvfp4_linear = do_nvfp4_linear_ref
# =====================================================================
# Production Nvfp4Linear wrapper
# =====================================================================
def make_nvfp4_linear(in_features, out_features, device, weight, weight_scale,
                      weight_scale_2=None, input_scale=None):
    """Create a production Nvfp4Linear with weights loaded from checkpoint.

    The packed weight and its scale factors are moved to ``device`` and
    stored as one-element lists on the layer (``fp4``, ``sf``); ``gs`` is
    taken from ``input_scale`` when given, otherwise a default is used.
    The caller is expected to call ``finalize_weights()`` on the result.
    """
    from dsv4.layers.linear import Nvfp4Linear

    # Default global scale; presumably 1 / (FP4 max 6 * FP8-E4M3 max 448) — confirm.
    default_gs = 1.0 / (6.0 * 448.0)

    lin = Nvfp4Linear(in_features, out_features, max_num_tokens=8192, device=device)
    lin.fp4 = [weight.to(device)]
    lin.sf = [weight_scale.to(device)]
    # NOTE(review): weight_scale_2 is accepted but never used here, whereas
    # dequant_nvfp4 multiplies the block scales by it.  Confirm that
    # Nvfp4Linear does not need it folded into `gs`.
    if input_scale is not None:
        gs = input_scale.float().item()
    else:
        gs = default_gs
    lin.gs = [gs]
    return lin
# =====================================================================
# Compressor — CSA (ratio=4) and HCA (ratio=128)
# (Reference PyTorch — compressor not yet on tensor cores)
# =====================================================================
class Compressor :
def __init__ ( self , ratio , head_dim , hidden_size , device ) :
@@ -224,10 +171,10 @@ class Compressor:
n_complete = T / / r
if n_complete == 0 :
return None , None , None
kv = nvfp4_linear ( hidden_states , self . wkv_w . to ( dev ) , self . wkv_ws . to ( dev ) ,
kv = nvfp4_linear_ref ( hidden_states , self . wkv_w . to ( dev ) , self . wkv_ws . to ( dev ) ,
self . wkv_ws2 . to ( dev ) if self . wkv_ws2 is not None else None ,
self . wkv_isc . to ( dev ) if self . wkv_isc is not None else None )
gate = nvfp4_linear ( hidden_states , self . wgate_w . to ( dev ) , self . wgate_ws . to ( dev ) ,
gate = nvfp4_linear_ref ( hidden_states , self . wgate_w . to ( dev ) , self . wgate_ws . to ( dev ) ,
self . wgate_ws2 . to ( dev ) if self . wgate_ws2 is not None else None ,
self . wgate_isc . to ( dev ) if self . wgate_isc is not None else None )
if self . ape is not None :
@@ -270,7 +217,7 @@ class Compressor:
return torch . stack ( comp_list ) , torch . stack ( comp_pos_list ) , torch . zeros ( 1 , T , n_complete , dtype = torch . float32 , device = dev )
# =====================================================================
# Indexer — CSA top-k
# Indexer — CSA top-k (Reference PyTorch)
# =====================================================================
class Indexer :
def __init__ ( self , n_ih , ihd , top_k , device ) :
@@ -292,11 +239,11 @@ class Indexer:
dev = q_lora . device
T = q_lora . shape [ 0 ]
n_comp = comp_indexer_kv . shape [ 0 ]
q_idx = nvfp4_linear ( q_lora , self . q_b_w . to ( dev ) , self . q_b_ws . to ( dev ) ,
q_idx = nvfp4_linear_ref ( q_lora , self . q_b_w . to ( dev ) , self . q_b_ws . to ( dev ) ,
self . q_b_ws2 . to ( dev ) if self . q_b_ws2 is not None else None ,
self . q_b_isc . to ( dev ) if self . q_b_isc is not None else None )
q_idx = q_idx . reshape ( T , self . n_ih , self . ihd )
w_h = nvfp4_linear ( hidden_states , self . wp_w . to ( dev ) , self . wp_ws . to ( dev ) ,
w_h = nvfp4_linear_ref ( hidden_states , self . wp_w . to ( dev ) , self . wp_ws . to ( dev ) ,
self . wp_ws2 . to ( dev ) if self . wp_ws2 is not None else None ,
self . wp_isc . to ( dev ) if self . wp_isc is not None else None )
k_idx = comp_indexer_kv . reshape ( n_comp , self . n_ih , self . ihd )
@@ -364,18 +311,36 @@ def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False):
return out
# =====================================================================
# Production FMHA — 6-warp TMA multi-tile kernel
# HcHead — FP32 projection, read out from mHC state
# =====================================================================
# Small positive constant added to the sigmoid gates as a numerical floor.
HC_EPS = 1e-6


class HcHead:
    """Read a single hidden vector per token out of the multi-stream state."""

    def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'):
        self.K, self.device, self.n_hc = n_hc * hidden_dim, device, n_hc

    def load(self, fn, base, scale=None):
        """Store the projection ``fn`` and bias ``base`` as FP32 tensors.

        ``scale`` is converted to a Python float; when it is ``None`` the
        scale defaults to ``1.0``.
        """
        self.fn = fn.to(self.device, torch.float32).contiguous()
        self.base = base.to(self.device, torch.float32).contiguous()
        if scale is not None:
            self.scale = scale.to(self.device, torch.float32).item()
        else:
            self.scale = 1.0

    def forward(self, X):
        """Collapse the stream axis (axis 1) of ``X`` with sigmoid gates.

        ``X`` is flattened to ``(T, self.K)`` for the gate projection, so
        it presumably has shape ``(T, n_hc, hidden_dim)`` — confirm with
        the caller.  Returns a BF16 tensor with axis 1 summed out.
        """
        T = X.shape[0]
        Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16())
        mix = F.linear(Xn, self.fn[:self.n_hc]).float()
        gate = torch.sigmoid(mix * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS
        return (gate.unsqueeze(-1) * X.float()).sum(1).bfloat16()
# =====================================================================
# Production FMHA — 6-warp TMA multi-tile kernel with sink bias
# =====================================================================
def _run_production_fmha ( q_heads , all_kv , n_h , hd , T , seq_len , scale , dev , li , w , pfx ) :
""" Run production FMHA kernel via dsv4_attention .
q_heads: (T, n_h, hd), all_kv: (seq_len, hd)
Returns: (T, n_h, hd) BF16
The 6-warp TMA FMHA kernel correctly handles N < 128:
K/V are padded to 128 for TMA alignment, but the kernel receives
the true s_k and masks padded entries in softmax (col < kv_len guard).
Fixed in fmha_multitile_capi.cu: N_orig (logical) vs N_padded (physical).
""" Run production FMHA kernel with sink bias support .
The kernel handles:
- N < 128: K/V padded to 128, kernel uses N_orig for softmax masking
- Multi-tile KV for N > 128
- Attention sinks via per-head logit bias (D5c: single softmax)
"""
from dsv4 . kernels . attention . production import dsv4_attention
@@ -394,12 +359,18 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
) # (n_h, T, hd)
return attn_out . permute ( 1 , 0 , 2 ) # (T, n_h, hd)
# =====================================================================
# Attention forward — uses production FMHA kernel
# Attention forward — production FMHA + production Nvfp4Linear
# =====================================================================
def forward_attention ( x_normed , w , li , cfg , rope_cos , rope_sin ,
kv_cache , positions , compressor , indexer ) :
kv_cache , positions , compressor , indexer ,
prod_lin = None ) :
""" Attention sub-block using production kernels.
All projections go through Nvfp4Linear (CuTeDSL GEMM).
FMHA goes through 6-warp TMA multi-tile kernel with sink bias.
Inverse RoPE applied after FMHA.
"""
dev = x_normed . device
T = x_normed . shape [ 0 ]
n_h = cfg [ " num_attention_heads " ]
@@ -414,19 +385,22 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
positions = positions . to ( rope_cos . device )
# 1. Q projection: q_a → q_a_norm → q_b → q_b_norm
q_a = do_nvfp4_linear ( x_normed , w , pfx , ' q_a_proj ' )
q_a = prod_lin [ ' q_a ' ] ( x_normed ) if prod_lin and ' q_a ' in prod_lin else \
do_nvfp4_linear_ref ( x_normed , w , pfx , ' q_a_proj ' )
if q_a is None :
log . warning ( f " L { li } : q_a_proj not found " )
return torch . zeros ( T , cfg [ " hidden_size " ] , dtype = torch . bfloat16 , device = dev ) , None
q_norm_w = w . get ( f " { pfx } .q_a_norm.weight " )
if q_norm_w is not None : q_a = rmsnorm ( q_a , q_norm_w . to ( dev , torch . float32 ) )
q = do_nvfp4_linear ( q_a , w , pfx , ' q_b_proj ' )
q = prod_lin [ ' q_b ' ] ( q_a ) if prod_lin and ' q_b ' in prod_lin else \
do_nvfp4_linear_ref ( q_a , w , pfx , ' q_b_proj ' )
q = unweighted_rmsnorm ( q ) . bfloat16 ( )
q_heads = q . reshape ( T , n_h , hd )
q_heads = _apply_rope ( q_heads , positions , rope_cos , rope_sin , rd )
# 2. KV projection (MQA, single KV head, hd dim)
kv = do_nvfp4_linear ( x_normed , w , pfx , ' kv_proj ' )
kv = prod_lin [ ' kv ' ] ( x_normed ) if prod_lin and ' kv ' in prod_lin else \
do_nvfp4_linear_ref ( x_normed , w , pfx , ' kv_proj ' )
if kv is None :
log . warning ( f " L { li } : kv_proj not found " )
return torch . zeros ( T , cfg [ " hidden_size " ] , dtype = torch . bfloat16 , device = dev ) , q_a
@@ -473,13 +447,13 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
if seq_len == 0 :
return torch . zeros ( T , cfg [ " hidden_size " ] , dtype = torch . bfloat16 , device = dev ) , q_a
# 6. Production FMHA kernel (6-warp TMA multi-tile)
# 6. Production FMHA kernel (6-warp TMA multi-tile) with sink bias
attn_out = _run_production_fmha ( q_heads , all_kv , n_h , hd , T , seq_len , scale , dev , li , w , pfx )
# 7. Inverse RoPE (FP32 cache)
attn_out = _apply_rope ( attn_out , positions , rope_cos , rope_sin , rd , inverse = True )
# 8. Output projection: wo_a (BF16 grouped BMM) + wo_b (NVFP4)
# 8. Output projection: wo_a (BF16 grouped BMM) + wo_b (NVFP4 GEMM )
hpg = n_h / / o_groups
gid = hpg * hd
oa_w = w . get ( f " { pfx } .o_a_proj.weight " )
@@ -490,108 +464,25 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
oa_3d = oa_bf . reshape ( o_groups , o_rank , gid )
g_out = torch . bmm ( a_grp . permute ( 1 , 0 , 2 ) , oa_3d . transpose ( 1 , 2 ) )
g_flat = g_out . permute ( 1 , 0 , 2 ) . reshape ( T , o_groups * o_rank )
F_attn = do_nvfp4_linear ( g_flat , w , pfx , ' o_b_proj ' )
F_attn = prod_lin [ ' o_b ' ] ( g_flat ) if prod_lin and ' o_b ' in prod_lin else \
do_nvfp4_linear_ref ( g_flat , w , pfx , ' o_b_proj ' )
else :
F_attn = do_nvfp4_linear ( attn_out . reshape ( T , n_h * hd ) , w , pfx , ' o_a_proj ' )
F_attn = prod_lin [ ' o_a ' ] ( attn_out . reshape ( T , n_h * hd ) ) if prod_lin and ' o_a ' in prod_lin else \
do_nvfp4_linear_ref ( attn_out . reshape ( T , n_h * hd ) , w , pfx , ' o_a_proj ' )
return F_attn , q_a
# =====================================================================
# MoE forward — uses production Nvfp4MoE + Nvfp4SharedExpert kernels
# MoE forward — production Nvfp4MoE + Nvfp4SharedExpert + Router
# =====================================================================
def moe_forward ( x , w , li , cfg , token_id , device , moe_runner , se_runner , router ) :
def moe_forward ( x , li , moe_runner , se_runner , router , token_id ):
""" MoE forward using production NVFP4 GEMM kernels.
Router uses production dense/hash router kernels.
Expert GEMMs use CuTeDSL NVFP4 grouped GEMM (fused SwiGLU).
Shared expert uses CuTeDSL NVFP4 single-group GEMM.
No F.linear. No BF16 matmul. No PyTorch loops over experts.
NO fallback to reference. Production kernels ONLY .
"""
H = cfg [ " hidden_size " ]
n_e = cfg [ " n_routed_experts " ]
top_k = cfg . get ( " num_experts_per_tok " , 6 )
rsc = cfg . get ( " routed_scaling_factor " , 2.5 )
lim = cfg . get ( " swiglu_limit " , 10.0 )
num_hash = cfg . get ( " num_hash_layers " , 3 )
pfx = f " model.layers. { li } .mlp "
# Production router: returns (topk_weights, topk_ids) via kernel
if router is not None :
try :
topk_w , topk_ids = router ( x , token_ids = token_id )
# Production MoE kernel: NVFP4 grouped GEMM with fused SwiGLU
routed_out = moe_runner ( x , topk_w , topk_ids )
# Production shared expert: NVFP4 single-group GEMM
shared_out = se_runner ( x )
return routed_out + shared_out
except Exception as e :
log . warning ( f " L { li } : Production MoE failed ( { e } ), falling back to reference " )
# Fall through to reference path
# Reference fallback (only if production kernels fail)
return _moe_forward_reference ( x , w , li , cfg , token_id , device )
def _reference_swiglu_expert(x, w, prefix, lim):
    """Run one gate/up/down SwiGLU expert stored under ``prefix`` in ``w``."""
    g = do_nvfp4_linear(x, w, prefix, 'gate_proj')
    u = do_nvfp4_linear(x, w, prefix, 'up_proj')
    silu = F.silu(g.float())
    if lim is not None:
        silu = silu.clamp(-lim, lim)
        u = u.float().clamp(-lim, lim)
    return do_nvfp4_linear((silu * u).bfloat16(), w, prefix, 'down_proj')


def _moe_forward_reference(x, w, li, cfg, token_id, device):
    """Reference MoE using dequantized BF16 weights.

    Expert selection is either a table lookup (``gate.tid2eid``, used for
    the first ``num_hash_layers`` layers when the table exists) or a
    learned gate with sqrt(softplus) scores.  Only the routing decision of
    the first token (row 0) is used, and it is applied to every row of
    ``x``.

    Args:
        x: activations fed to the experts.
        w: mapping from checkpoint key to tensor for this layer.
        li: layer index, used to build the key prefix.
        cfg: model config mapping.
        token_id: tensor of token ids; the first entry drives hash routing.
        device: device the gate weights are moved to.

    Returns:
        Scaled routed-expert output plus the shared-expert output.

    Raises:
        ValueError: if the layer is not hash-routed and has no gate weight.
    """
    top_k = cfg.get("num_experts_per_tok", 6)
    rsc = cfg.get("routed_scaling_factor", 2.5)
    lim = cfg.get("swiglu_limit", 10.0)
    num_hash = cfg.get("num_hash_layers", 3)
    pfx = f"model.layers.{li}.mlp"

    tid2eid_key = f"{pfx}.gate.tid2eid"
    e_bias_key = f"{pfx}.gate.e_score_correction_bias"
    is_hash = (li < num_hash) and (tid2eid_key in w)

    if is_hash:
        tid = token_id.item() if token_id.numel() == 1 else token_id[0].item()
        expert_ids = w[tid2eid_key][tid]
        expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
    else:
        gate_ww, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate')
        if gate_ww is not None and gate_ws is not None:
            logits = nvfp4_linear(
                x, gate_ww.to(device), gate_ws.to(device),
                gate_ws2.to(device) if gate_ws2 is not None else None,
                gate_isc.to(device) if gate_isc is not None else None,
            )
        elif f"{pfx}.gate.weight" in w:
            logits = F.linear(x, w[f"{pfx}.gate.weight"].bfloat16().to(device))
        else:
            raise ValueError(f"No gate weight for layer {li}")

        scores = torch.sqrt(F.softplus(logits.float()) + 1e-6)
        # The bias only influences which experts are picked, not their weights.
        sel = scores
        if e_bias_key in w:
            sel = scores + w[e_bias_key].to(device=x.device).float().unsqueeze(0)
        _, indices = sel.topk(top_k, -1)
        picked = torch.gather(scores, -1, indices)
        picked = picked / picked.sum(-1, keepdim=True)
        expert_ids, expert_weights = indices[0], picked[0]

    # Convert once to Python scalars instead of one .item() per expert.
    expert_id_list = expert_ids.tolist()
    expert_weight_list = expert_weights.tolist()

    expert_outs = [
        _reference_swiglu_expert(x, w, f"{pfx}.experts.{eid}", lim)
        for eid in expert_id_list
    ]

    routed = torch.zeros_like(x)
    for out, wt in zip(expert_outs, expert_weight_list):
        routed = routed + (out.float() * wt).bfloat16()
    routed = (routed.float() * rsc).bfloat16()

    shared = _reference_swiglu_expert(x, w, f"{pfx}.shared_experts", lim)
    return routed + shared
topk_w , topk_ids = router ( x , token_ids = token_id )
routed_out = moe_runner ( x , topk_w , topk_ids )
shared_out = se_runner ( x )
return routed_out + shared_out
# =====================================================================
# Layer forward
@@ -600,18 +491,20 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
attn_mhc , ffn_mhc , attn_norm_w , ffn_norm_w ,
kv_cache , positions , token_id ,
compressor = None , indexer = None ,
moe_runner = None , se_runner = None , router = None ) :
moe_runner = None , se_runner = None , router = None ,
prod_lin = None ) :
dev = X_l . device
# Attention sub-block
x_in , ctx_a = attn_mhc . pre_block ( X_l )
x_normed = rmsnorm ( x_in , attn_norm_w )
F_attn , _ = forward_attention ( x_normed , w , li , cfg , rope_cos , rope_sin ,
kv_cache , positions , compressor , indexer )
kv_cache , positions , compressor , indexer ,
prod_lin = prod_lin )
X_mid = attn_mhc . post_block ( X_l , F_attn , ctx_a )
# FFN sub-block
x_in_f , ctx_f = ffn_mhc . pre_block ( X_mid )
x_ffn = rmsnorm ( x_in_f , ffn_norm_w )
F_ffn = moe_forward ( x_ffn , w , li , cfg , token_id , dev , moe_runner , se_runner , router )
F_ffn = moe_forward ( x_ffn , li , moe_runner , se_runner , router , token_id )
X_next = ffn_mhc . post_block ( X_mid , F_ffn , ctx_f )
if VERBOSE > = 1 :
print ( f " L { li } : |X|= { X_l . abs ( ) . max ( ) . item ( ) : .1f } → { X_next . abs ( ) . max ( ) . item ( ) : .1f } "
@@ -619,15 +512,132 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
return X_next
# =====================================================================
# Main
# MoE weight loading (stacked path for production GEMM)
# =====================================================================
def _load_moe_weights_stacked ( all_w , li , pfx , dev , moe , cfg ) :
n_e = cfg [ " n_routed_experts " ]
w0 = all_w . get ( f " { pfx } .experts.0.gate_proj.weight " )
if w0 is None :
log . warning ( f " L { li } : No expert weights found " )
return
gate_N , gate_K = w0 . shape
l1_stacked = torch . zeros ( n_e , 2 * gate_N , gate_K , dtype = w0 . dtype )
l1_sf_stacked = None
l2_stacked = None
l2_sf_stacked = None
l1_gs = [ ]
l2_gs = [ ]
ws0 = all_w . get ( f " { pfx } .experts.0.gate_proj.weight_scale " )
if ws0 is not None :
sf_N , sf_K = ws0 . shape
l1_sf_stacked = torch . zeros ( n_e , 2 * sf_N , sf_K , dtype = ws0 . dtype )
dw0 = all_w . get ( f " { pfx } .experts.0.down_proj.weight " )
if dw0 is not None :
down_N , down_K = dw0 . shape
l2_stacked = torch . zeros ( n_e , down_N , down_K , dtype = dw0 . dtype )
dws0 = all_w . get ( f " { pfx } .experts.0.down_proj.weight_scale " )
if dws0 is not None :
l2_sf_stacked = torch . zeros ( n_e , dws0 . shape [ 0 ] , dws0 . shape [ 1 ] , dtype = dws0 . dtype )
for eid in range ( n_e ) :
gw = all_w . get ( f " { pfx } .experts. { eid } .gate_proj.weight " )
gws = all_w . get ( f " { pfx } .experts. { eid } .gate_proj.weight_scale " )
gisc = all_w . get ( f " { pfx } .experts. { eid } .gate_proj.input_scale " )
uw = all_w . get ( f " { pfx } .experts. { eid } .up_proj.weight " )
uws = all_w . get ( f " { pfx } .experts. { eid } .up_proj.weight_scale " )
if gw is not None and uw is not None :
l1_stacked [ eid , : gate_N ] = gw
l1_stacked [ eid , gate_N : ] = uw
if gws is not None and uws is not None and l1_sf_stacked is not None :
l1_sf_stacked [ eid , : sf_N ] = gws
l1_sf_stacked [ eid , sf_N : ] = uws
l1_gs . append ( gisc . float ( ) . item ( ) if gisc is not None else 1.0 / ( 6.0 * 448.0 ) )
dw = all_w . get ( f " { pfx } .experts. { eid } .down_proj.weight " )
dws = all_w . get ( f " { pfx } .experts. { eid } .down_proj.weight_scale " )
disc = all_w . get ( f " { pfx } .experts. { eid } .down_proj.input_scale " )
if dw is not None :
l2_stacked [ eid ] = dw
if dws is not None and l2_sf_stacked is not None :
l2_sf_stacked [ eid ] = dws
l2_gs . append ( disc . float ( ) . item ( ) if disc is not None else 1.0 / ( 6.0 * 448.0 ) )
l1_stacked = l1_stacked . to ( dev )
l1_sf_stacked = l1_sf_stacked . to ( dev ) if l1_sf_stacked is not None else None
l2_stacked = l2_stacked . to ( dev ) if l2_stacked is not None else None
l2_sf_stacked = l2_sf_stacked . to ( dev ) if l2_sf_stacked is not None else None
l1_gs = l1_gs if l1_gs else [ 1.0 / ( 6.0 * 448.0 ) ] * n_e
l2_gs = l2_gs if l2_gs else [ 1.0 / ( 6.0 * 448.0 ) ] * n_e
moe . prepare_weights_from_stacked ( l1_stacked , l1_sf_stacked , l1_gs ,
l2_stacked , l2_sf_stacked , l2_gs )
def _load_shared_expert_weights ( all_w , li , pfx , dev , se , cfg ) :
l1_gate_fp4 , l1_gate_sf , l1_gate_gs = [ ] , [ ] , [ ]
l1_up_fp4 , l1_up_sf = [ ] , [ ]
l2_fp4 , l2_sf , l2_gs = [ ] , [ ] , [ ]
for proj , fp4_l , sf_l , gs_l in [
( ' gate_proj ' , l1_gate_fp4 , l1_gate_sf , l1_gate_gs ) ,
( ' up_proj ' , l1_up_fp4 , l1_up_sf , None ) ,
( ' down_proj ' , l2_fp4 , l2_sf , l2_gs ) ,
] :
w , ws , isc = all_w . get ( f " { pfx } .shared_experts. { proj } .weight " ) , \
all_w . get ( f " { pfx } .shared_experts. { proj } .weight_scale " ) , \
all_w . get ( f " { pfx } .shared_experts. { proj } .input_scale " )
if w is not None and ws is not None :
fp4_l . append ( w . to ( dev ) )
sf_l . append ( ws . to ( dev ) )
if gs_l is not None :
gs_l . append ( isc . float ( ) . item ( ) if isc is not None else 1.0 / ( 6.0 * 448.0 ) )
if l1_gate_fp4 and l1_up_fp4 :
se . l1_fp4 = [ torch . cat ( [ l1_gate_fp4 [ 0 ] , l1_up_fp4 [ 0 ] ] , dim = 0 ) ]
se . l1_sf = [ torch . cat ( [ l1_gate_sf [ 0 ] , l1_up_sf [ 0 ] ] , dim = 0 ) ]
se . l1_gs = l1_gate_gs if l1_gate_gs else [ 1.0 / ( 6.0 * 448.0 ) ]
if l2_fp4 :
se . l2_fp4 = l2_fp4 ; se . l2_sf = l2_sf
se . l2_gs = l2_gs if l2_gs else [ 1.0 / ( 6.0 * 448.0 ) ]
se . finalize_weights ( )
def _cache_layer_weights_no_experts ( all_w , n_layers , devices ) :
""" Cache per-layer weights to GPUs, EXCLUDING MoE expert weights. """
cached = { }
for li in range ( n_layers ) :
dev = devices [ li % len ( devices ) ]
pfx = f " model.layers. { li } . "
w = { k : v . to ( device = dev , non_blocking = True )
for k , v in all_w . items ( )
if k . startswith ( pfx ) and ' .experts. ' not in k and ' .shared_experts. ' not in k }
cached [ li ] = w
if ( li + 1 ) % 10 == 0 : print ( f " Cached { li + 1 } / { n_layers } layers " )
return cached
def load_weights(checkpoint_dir):
    """Load every safetensors shard of a checkpoint into one dict.

    Shard names come from the ``weight_map`` of
    ``model.safetensors.index.json``.  When that index file does not exist
    (e.g. a single-file checkpoint) every ``*.safetensors`` file in the
    directory is loaded instead.  Shards are read in sorted name order.
    A shard that is listed but missing on disk is skipped with a warning.

    Args:
        checkpoint_dir: directory containing the checkpoint.

    Returns:
        Mapping from tensor name to tensor; empty if nothing was found.
    """
    cdir = Path(checkpoint_dir)
    idx = cdir / "model.safetensors.index.json"
    if idx.exists():
        with open(idx, encoding="utf-8") as f:
            wmap = json.load(f).get("weight_map", {})
        shard_names = sorted(set(wmap.values()))
    else:
        shard_names = sorted(p.name for p in cdir.glob("*.safetensors"))

    present = []
    for sn in shard_names:
        if (cdir / sn).exists():
            present.append(sn)
        else:
            logging.getLogger(__name__).warning(
                "Shard %s is listed in the index but missing from %s", sn, cdir)

    all_w = {}
    if not present:
        return all_w

    # Imported lazily so an empty checkpoint directory does not require the
    # dependency.
    from safetensors.torch import load_file
    for sn in present:
        all_w.update(load_file(str(cdir / sn)))
    return all_w
def main ( ) :
t0 = time . time ( )
torch . manual_seed ( SEED )
print ( " = " * 70 )
print ( " DSV4 Single-Shot Inference — PRODUCTION KERNEL STACK " )
print ( " FMHA: 6-warp TMA multi-tile | Compressor + Indexer | mHC | MoE " )
print ( " NVFP4 GEMM (CuTeDSL) | Router kernels | NO PyTorch SDPA " )
print ( " FMHA: 6-warp TMA multi-tile + sink bias " )
print ( " NVFP4 GEMM (CuTeDSL) | Router kernels | Production mHC " )
print ( " NO PyTorch SDPA | NO dequant+matmul | NO reference fallback " )
print ( " = " * 70 )
with open ( os . path . join ( CHECKPOINT_DIR , " config.json " ) ) as f :
@@ -641,17 +651,18 @@ def main():
print ( f " Compress ratios: first5= { cr [ : 5 ] } len= { len ( cr ) } " )
print ( f " Experts: { cfg [ ' n_routed_experts ' ] } , top- { cfg . get ( ' num_experts_per_tok ' , 6 ) } " )
# Load weights
# ---- Phase 1: Load weights ----
print ( f " \n Phase 1: Loading weights... " )
all_w = load_weights ( CHECKPOINT_DIR )
all_w = load_all_ weights ( CHECKPOINT_DIR )
print ( f " { time . time ( ) - t0 : .1f } s " )
# Build production components
# ---- Phase 2: Build production components ----
print ( " Building production components... " )
from dsv4 . layers . mhc import mHCLayer
from dsv4 . layers . router import Router
from dsv4 . layers . moe import Nvfp4MoE
from dsv4 . layers . shared_expert import Nvfp4SharedExpert
from dsv4 . layers . linear import Nvfp4Linear
# mHC + norms
attn_mhcs , ffn_mhcs , attn_norms , ffn_norms = { } , { } , { } , { }
@@ -665,8 +676,20 @@ def main():
] :
fn , base , scale = all_w . get ( fn_s ) , all_w . get ( base_s ) , all_w . get ( scale_s )
if fn is not None and base is not None and scale is not None :
m = mHCBlock ( H , 4 , 20 , dev )
m . load ( fn , base , scale )
m = mHCLayer ( hidden_dim = H , n_hc = 4 , t_max_sinkhorn = 20 , device = dev )
# Split fn/base/scale into pre/post/comb
n = 4
m . load_weights (
W_pre = fn [ 0 : n ] . to ( dev , torch . float32 ) ,
W_post = fn [ n : 2 * n ] . to ( dev , torch . float32 ) ,
W_comb = fn [ 2 * n : ] . to ( dev , torch . float32 ) ,
S_pre = base [ 0 : n ] . reshape ( 1 , n ) . to ( dev , torch . float32 ) ,
S_post = base [ n : 2 * n ] . reshape ( n , 1 ) . to ( dev , torch . float32 ) ,
S_comb = base [ 2 * n : ] . reshape ( n , n ) . to ( dev , torch . float32 ) ,
alpha_pre = scale [ 0 ] . item ( ) ,
alpha_post = scale [ 1 ] . item ( ) ,
alpha_comb = scale [ 2 ] . item ( ) ,
)
blocks [ li ] = m
an_k = f " model.layers. { li } .input_layernorm.weight "
@@ -674,6 +697,27 @@ def main():
fn_k = f " model.layers. { li } .post_attention_layernorm.weight "
if fn_k in all_w : ffn_norms [ li ] = all_w [ fn_k ] . to ( dev , torch . float32 )
# Production Nvfp4Linear for attention projections
prod_lins = { }
for li in range ( n_layers ) :
dev = f " cuda: { li % NUM_GPUS } "
pfx = f " model.layers. { li } .self_attn "
plin = { }
for proj , in_f , out_f in [
( ' q_a ' , H , cfg . get ( ' query_compression_dim ' , 1536 ) ) ,
( ' q_b ' , cfg . get ( ' query_compression_dim ' , 1536 ) , n_h * hd ) ,
( ' kv ' , H , hd ) ,
( ' o_b ' , cfg . get ( ' o_groups ' , 16 ) * cfg . get ( ' o_lora_rank ' , 1024 ) , H ) ,
] :
wt , ws , ws2 , isc = get_nvfp4_weight ( all_w , pfx , proj )
if wt is not None and ws is not None :
lin = make_nvfp4_linear ( in_f , out_f , dev , wt , ws , ws2 , isc )
lin . finalize_weights ( )
plin [ proj ] = lin
if plin :
prod_lins [ li ] = plin
if ( li + 1 ) % 10 == 0 : print ( f " Built Nvfp4Linear { li + 1 } / { n_layers } layers " )
# Routers, MoE, shared experts
routers , moe_runners , se_runners = { } , { } , { }
for li in range ( n_layers ) :
@@ -681,7 +725,6 @@ def main():
pfx = f " model.layers. { li } .mlp "
is_hash = ( li < cfg . get ( " num_hash_layers " , 3 ) ) and ( f " { pfx } .gate.tid2eid " in all_w )
# Router
router = Router (
hidden_size = H , num_experts = cfg [ " n_routed_experts " ] ,
top_k = cfg . get ( " num_experts_per_tok " , 6 ) ,
@@ -700,19 +743,15 @@ def main():
router . finalize_weights ( )
routers [ li ] = router
# MoE (production NVFP4 grouped GEMM)
moe = Nvfp4MoE (
num_experts = cfg [ " n_routed_experts " ] , hidden_size = H ,
intermediate_size = cfg . get ( " moe_intermediate_size " , 3072 ) ,
top_k = cfg . get ( " num_experts_per_tok " , 6 ) , device = dev ,
)
moe . set_swiglu_limit ( cfg . get ( " swiglu_limit " , 10.0 ) )
# Load expert weights (stacked path)
_load_moe_weights_stacked ( all_w , li , pfx , dev , moe , cfg )
moe_runners [ li ] = moe
# Shared expert (production NVFP4 single-group GEMM)
se = Nvfp4SharedExpert (
hidden_size = H , intermediate_size = cfg . get ( " moe_intermediate_size " , 3072 ) ,
device = dev , swiglu_limit = cfg . get ( " swiglu_limit " , 10.0 ) ,
@@ -761,7 +800,6 @@ def main():
if ratio == 4 : indexers [ li ] = Indexer ( n_ih , ihd , itk , dev )
# Cache layer weights (EXCLUDE MoE/SE expert weights — handled by production runners)
# This avoids double-loading ~10GB/layer of expert FP4 weights
print ( " Caching layer weights to GPUs (excluding MoE expert weights)... " )
devs = [ f " cuda: { g } " for g in range ( NUM_GPUS ) ]
layer_w = _cache_layer_weights_no_experts ( all_w , n_layers , devs )
@@ -778,8 +816,8 @@ def main():
if li in indexers : indexers [ li ] . load ( layer_w [ li ] , f " { pfx } .indexer " )
print ( " Compressors/indexers loaded " )
# Phase 2 : Inference
print ( f " \n Phase 2 : Inference " )
# ---- Phase 3 : Inference ----
print ( f " \n Phase 3 : Inference " )
from transformers import AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( CHECKPOINT_DIR )
@@ -796,7 +834,7 @@ def main():
t1 = time . time ( )
tid = torch . tensor ( [ tid_val ] , dtype = torch . long , device = ' cuda:0 ' )
pos = torch . tensor ( [ pi ] , dtype = torch . long , device = ' cuda:0 ' )
X = mHCBlock . init_state ( embed ( tid ) )
X = mHCLayer . init_state ( embed ( tid ) )
for li in range ( n_layers ) :
gpu = li % NUM_GPUS
if X . device != torch . device ( f " cuda: { gpu } " ) : X = X . to ( f " cuda: { gpu } " )
@@ -806,7 +844,8 @@ def main():
attn_norms . get ( li ) , ffn_norms . get ( li ) ,
kv_caches [ li ] , pos , tid ,
compressors . get ( li ) , indexers . get ( li ) ,
moe_runners . get ( li ) , se_runners . get ( li ) , routers . get ( li ) )
moe_runners . get ( li ) , se_runners . get ( li ) , routers . get ( li ) ,
prod_lin = prod_lins . get ( li ) )
X = X . to ( ' cuda:0 ' ) ; torch . cuda . set_device ( 0 )
if pi % 10 == 0 : print ( f " Token { pi } / { len ( generated ) } : { time . time ( ) - t1 : .2f } s " , flush = True )
print ( f " Prefill done ( { time . time ( ) - t0 : .1f } s) " )
@@ -822,7 +861,7 @@ def main():
t1 = time . time ( )
tid = torch . tensor ( [ all_tokens [ - 1 ] ] , dtype = torch . long , device = ' cuda:0 ' )
dec_pos = torch . tensor ( [ len ( all_tokens ) - 1 ] , dtype = torch . long , device = ' cuda:0 ' )
X = mHCBlock . init_state ( embed ( tid ) )
X = mHCLayer . init_state ( embed ( tid ) )
for li in range ( n_layers ) :
gpu = li % NUM_GPUS
if X . device != torch . device ( f " cuda: { gpu } " ) : X = X . to ( f " cuda: { gpu } " )
@@ -832,7 +871,8 @@ def main():
attn_norms . get ( li ) , ffn_norms . get ( li ) ,
kv_caches [ li ] , dec_pos , tid ,
compressors . get ( li ) , indexers . get ( li ) ,
moe_runners . get ( li ) , se_runners . get ( li ) , routers . get ( li ) )
moe_runners . get ( li ) , se_runners . get ( li ) , routers . get ( li ) ,
prod_lin = prod_lins . get ( li ) )
X = X . to ( ' cuda:0 ' ) ; torch . cuda . set_device ( 0 )
x_out = hc_head . forward ( X ) if hc_head is not None else X [ : , 0 , : ]
if final_norm_w is not None : x_out = rmsnorm ( x_out , final_norm_w )
@@ -858,157 +898,6 @@ def main():
print ( f " Total: { time . time ( ) - t0 : .1f } s " )
print ( f " { ' = ' * 70 } " )
# =====================================================================
# MoE weight loading helpers (stacked path for production GEMM)
# =====================================================================
def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
    """Load one layer's routed-expert weights into ``moe`` via the stacked path.

    Per-expert tensors are copied into preallocated stacked buffers (created
    with ``torch.zeros`` on the default device), and each stacked buffer is
    then moved to ``dev`` with a single ``.to(dev)`` call before being handed
    to ``moe.prepare_weights_from_stacked``.

    Layout of the stacked buffers, as established by the copies below:
      * L1 weights: ``(n_experts, 2 * gate_rows, gate_cols)`` with the
        gate_proj rows first and the up_proj rows second.
      * L1 scale factors: same gate-then-up layout, sized from expert 0's
        ``gate_proj.weight_scale``.
      * L2 weights / scale factors: ``(n_experts, rows, cols)`` sized from
        expert 0's ``down_proj`` tensors.

    Buffers whose expert-0 tensor is absent are passed on as ``None``.
    Experts whose tensors are missing keep zero rows in the buffers.

    Args:
        all_w: Mapping from checkpoint tensor name to tensor.
        li: Layer index; used only in the warning message.
        pfx: Key prefix of the layer's MLP module; expert keys are looked up
            as ``{pfx}.experts.{eid}.{proj}.{suffix}``.
        dev: Device the stacked tensors are moved to.
        moe: Runner object exposing ``prepare_weights_from_stacked``.
        cfg: Model config mapping; only ``n_routed_experts`` is read.

    Returns:
        None. Returns early (after a warning) when expert 0 has no
        ``gate_proj.weight``.
    """
    # Scale used when an expert has no ``input_scale`` entry.
    # NOTE(review): presumably 1 / (FP4 max * FP8 max) - confirm.
    default_scale = 1.0 / (6.0 * 448.0)
    n_e = cfg["n_routed_experts"]
    experts_pfx = f"{pfx}.experts"

    # Expert 0 defines the buffer shapes and dtypes.
    w0 = all_w.get(f"{experts_pfx}.0.gate_proj.weight")
    if w0 is None:
        log.warning(f"L{li}: No expert weights found")
        return
    gate_N, gate_K = w0.shape
    l1_stacked = torch.zeros(n_e, 2 * gate_N, gate_K, dtype=w0.dtype)

    l1_sf_stacked = None
    sf_N = 0
    ws0 = all_w.get(f"{experts_pfx}.0.gate_proj.weight_scale")
    if ws0 is not None:
        sf_N, sf_K = ws0.shape
        l1_sf_stacked = torch.zeros(n_e, 2 * sf_N, sf_K, dtype=ws0.dtype)

    l2_stacked = None
    dw0 = all_w.get(f"{experts_pfx}.0.down_proj.weight")
    if dw0 is not None:
        down_N, down_K = dw0.shape
        l2_stacked = torch.zeros(n_e, down_N, down_K, dtype=dw0.dtype)

    l2_sf_stacked = None
    dws0 = all_w.get(f"{experts_pfx}.0.down_proj.weight_scale")
    if dws0 is not None:
        dsf_N, dsf_K = dws0.shape
        l2_sf_stacked = torch.zeros(n_e, dsf_N, dsf_K, dtype=dws0.dtype)

    # One scale entry is appended per expert so the lists stay aligned with
    # the expert axis of the stacked buffers.
    l1_gs = []
    l2_gs = []
    for eid in range(n_e):
        e_pfx = f"{experts_pfx}.{eid}"

        # L1: gate rows followed by up rows.
        gw = all_w.get(f"{e_pfx}.gate_proj.weight")
        gws = all_w.get(f"{e_pfx}.gate_proj.weight_scale")
        gisc = all_w.get(f"{e_pfx}.gate_proj.input_scale")
        uw = all_w.get(f"{e_pfx}.up_proj.weight")
        uws = all_w.get(f"{e_pfx}.up_proj.weight_scale")
        if gw is not None and uw is not None:
            l1_stacked[eid, :gate_N] = gw
            l1_stacked[eid, gate_N:] = uw
        if gws is not None and uws is not None and l1_sf_stacked is not None:
            l1_sf_stacked[eid, :sf_N] = gws
            l1_sf_stacked[eid, sf_N:] = uws
        l1_gs.append(gisc.float().item() if gisc is not None else default_scale)

        # L2: down projection.
        dw = all_w.get(f"{e_pfx}.down_proj.weight")
        dws = all_w.get(f"{e_pfx}.down_proj.weight_scale")
        disc = all_w.get(f"{e_pfx}.down_proj.input_scale")
        # Guard on the buffer too: it is None when expert 0 had no down_proj.
        if dw is not None and l2_stacked is not None:
            l2_stacked[eid] = dw
        if dws is not None and l2_sf_stacked is not None:
            l2_sf_stacked[eid] = dws
        l2_gs.append(disc.float().item() if disc is not None else default_scale)

    # Move each stacked buffer to the target device in a single transfer.
    l1_stacked = l1_stacked.to(dev)
    if l1_sf_stacked is not None:
        l1_sf_stacked = l1_sf_stacked.to(dev)
    if l2_stacked is not None:
        l2_stacked = l2_stacked.to(dev)
    if l2_sf_stacked is not None:
        l2_sf_stacked = l2_sf_stacked.to(dev)

    moe.prepare_weights_from_stacked(
        l1_stacked, l1_sf_stacked, l1_gs,
        l2_stacked, l2_sf_stacked, l2_gs,
    )
def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
    """Load one layer's shared-expert weights into ``se`` and finalize it.

    Looks up ``{pfx}.shared_experts.{proj}.{weight,weight_scale,input_scale}``
    for the gate, up and down projections. A projection counts as present
    only when both its ``weight`` and ``weight_scale`` exist.

    Effects on ``se``:
      * When gate and up are both present, ``l1_fp4`` / ``l1_sf`` become
        single-element lists holding the gate and up tensors concatenated
        along dim 0 (gate first), and ``l1_gs`` holds the gate projection's
        scale.
      * When down is present, ``l2_fp4`` / ``l2_sf`` / ``l2_gs`` become
        single-element lists for the down projection.
      * ``se.finalize_weights()`` is always called at the end.

    Args:
        all_w: Mapping from checkpoint tensor name to tensor.
        li: Layer index (unused; kept for signature compatibility).
        pfx: Key prefix of the layer's MLP module.
        dev: Device the weight and scale tensors are moved to.
        se: Shared-expert runner receiving the attributes listed above.
        cfg: Model config (unused; kept for signature compatibility).
    """
    # Scale used when a projection has no ``input_scale`` entry.
    default_scale = 1.0 / (6.0 * 448.0)

    def _fetch(proj):
        """Return (weight, weight_scale, input_scale) on ``dev`` or None."""
        base = f"{pfx}.shared_experts.{proj}"
        w = all_w.get(f"{base}.weight")
        ws = all_w.get(f"{base}.weight_scale")
        if w is None or ws is None:
            return None
        return w.to(dev), ws.to(dev), all_w.get(f"{base}.input_scale")

    def _scale(isc):
        """Convert an optional input_scale tensor to a Python float."""
        return isc.float().item() if isc is not None else default_scale

    gate = _fetch("gate_proj")
    up = _fetch("up_proj")
    down = _fetch("down_proj")

    if gate is not None and up is not None:
        se.l1_fp4 = [torch.cat([gate[0], up[0]], dim=0)]
        se.l1_sf = [torch.cat([gate[1], up[1]], dim=0)]
        # Only the gate projection's input_scale is used for L1.
        se.l1_gs = [_scale(gate[2])]
    if down is not None:
        se.l2_fp4 = [down[0]]
        se.l2_sf = [down[1]]
        se.l2_gs = [_scale(down[2])]
    se.finalize_weights()
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
    """Copy each layer's non-expert weights to its device.

    Layer ``li`` is assigned ``devices[li % len(devices)]``. A tensor belongs
    to layer ``li`` when its key starts with ``model.layers.{li}.`` (the
    trailing dot keeps layer 1 from matching layer 10). Keys containing
    ``.experts.`` or ``.shared_experts.`` are skipped, so routed and shared
    expert weights are not copied by this function.

    Args:
        all_w: Mapping from checkpoint tensor name to tensor.
        n_layers: Number of layers to cache (indices ``0 .. n_layers - 1``).
        devices: Sequence of device identifiers, assigned round-robin.

    Returns:
        Dict mapping layer index to a ``{key: tensor_on_device}`` dict.

    Raises:
        ValueError: If ``devices`` is empty while ``n_layers`` is positive.
    """
    if n_layers > 0 and not devices:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("devices must be non-empty when n_layers > 0")

    cached = {}
    for li in range(n_layers):
        dev = devices[li % len(devices)]
        layer_pfx = f"model.layers.{li}."
        cached[li] = {
            key: tensor.to(device=dev, non_blocking=True)
            for key, tensor in all_w.items()
            if key.startswith(layer_pfx)
            and ".experts." not in key
            and ".shared_experts." not in key
        }
        if (li + 1) % 10 == 0:
            print(f"  Cached {li + 1}/{n_layers} layers")
    return cached
def load_weights(checkpoint_dir):
    """Load every shard listed in a checkpoint's safetensors index.

    Reads ``model.safetensors.index.json`` from ``checkpoint_dir``, collects
    the distinct shard file names from its ``weight_map``, and merges the
    tensors of each existing shard (in sorted file-name order) into a single
    dict. A missing index or missing shard files are logged as warnings and
    otherwise skipped, so the result may be empty or partial.

    Args:
        checkpoint_dir: Directory containing the index and shard files.

    Returns:
        Dict mapping tensor name to tensor, as returned by
        ``safetensors.torch.load_file`` for each shard.
    """
    logger = logging.getLogger(__name__)
    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    if not index_path.exists():
        logger.warning("No safetensors index found at %s", index_path)
        return {}

    with open(index_path, encoding="utf-8") as f:
        weight_map = json.load(f).get("weight_map", {})

    shard_paths = []
    for shard_name in sorted(set(weight_map.values())):
        shard_path = cdir / shard_name
        if shard_path.exists():
            shard_paths.append(shard_path)
        else:
            logger.warning("Checkpoint shard missing, skipped: %s", shard_path)

    all_w = {}
    if shard_paths:
        # Imported lazily so an empty checkpoint does not require safetensors.
        from safetensors.torch import load_file

        for shard_path in shard_paths:
            all_w.update(load_file(str(shard_path)))
    return all_w
# Run the pipeline exactly once, and only when executed as a script.
if __name__ == "__main__":
    main()