From e8a7a9256f421d2b8044631bdb8118c1a6dcacf3 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 00:18:34 +0000 Subject: [PATCH] fix: convert uint8 checkpoint weights to float4_e2m1fn_x2 for CuTeDSL GEMM The CuTeDSL kernel expects the float4_e2m1fn_x2 dtype for FP4 weight tensors, but checkpoint weights from safetensors are loaded as uint8. The uint8 and float4_e2m1fn_x2 dtypes have the same byte representation, so .view() is safe. Fixed in: - Nvfp4Linear.finalize_weights() - Nvfp4SharedExpert.finalize_weights() - Nvfp4MoE._ensure_stacked() (both stacked and legacy paths) --- dsv4/layers/linear.py | 4 +++- dsv4/layers/moe.py | 12 +++++++++++- dsv4/layers/shared_expert.py | 8 +++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index e708b598..6ae4b955 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -70,7 +70,9 @@ class Nvfp4Linear: def finalize_weights(self): """Process weights for CuTeDSL GEMM.""" - self._mat_b = make_b_k_major(torch.stack(self.fp4)) # (1, K_packed, N_packed) + # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view + fp4_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.fp4] + self._mat_b = make_b_k_major(torch.stack(fp4_view)) # (1, K_packed, N_packed) self._scale_b = assemble_scales_3d_side(self.sf) self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device) diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 038ce14e..96502aaa 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -210,6 +210,11 @@ class Nvfp4MoE: # This pairs gate/up within the MMA accumulator, enabling # fused SwiGLU without runtime conditionals. 
l1_fp4_ekn = interleave_l1_weights(l1_fp4_ekn) + # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view + if l1_fp4_ekn.dtype == torch.uint8: + l1_fp4_ekn = l1_fp4_ekn.view(torch.float4_e2m1fn_x2) + if l2_fp4_ekn.dtype == torch.uint8: + l2_fp4_ekn = l2_fp4_ekn.view(torch.float4_e2m1fn_x2) # Free stacked checkpoints before make_b_k_major (saves one copy) self.l1_fp4_stacked = None self.l2_fp4_stacked = None @@ -253,8 +258,13 @@ class Nvfp4MoE: # Legacy path: per-expert lists l1_stacked = torch.stack(self.l1_fp4) # (E, K, N) l1_stacked = interleave_l1_weights(l1_stacked) # interleave gate/up + if l1_stacked.dtype == torch.uint8: + l1_stacked = l1_stacked.view(torch.float4_e2m1fn_x2) + l2_stacked = torch.stack(self.l2_fp4) + if l2_stacked.dtype == torch.uint8: + l2_stacked = l2_stacked.view(torch.float4_e2m1fn_x2) self._l1_mat_b = make_b_k_major(l1_stacked) - self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) + self._l2_mat_b = make_b_k_major(l2_stacked) # Interleave L1 SF to match weight interleave # SF from quantize_weight_to_nvfp4 is (K_sf, N). Interleave along N, # then transpose to (N, K_sf) for swizzle via assemble_scales_3d_side. diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index aa0fb90a..ffb8eb37 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -102,10 +102,12 @@ class Nvfp4SharedExpert: def finalize_weights(self): """Process weights for CuTeDSL GEMM. 
Must be called after setting l1/l2 weights.""" + # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view + l1_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l1_fp4] + l2_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l2_fp4] # Stack weights and convert to K-major - # l1_fp4/l2_fp4 are lists with 1 element (the shared expert) - self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) # (1, K_packed, N_packed) - self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) + self._l1_mat_b = make_b_k_major(torch.stack(l1_view)) # (1, K_packed, N_packed) + self._l2_mat_b = make_b_k_major(torch.stack(l2_view)) self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) # (1, N, K_sf_padded) self._l2_scale_b = assemble_scales_3d_side(self.l2_sf) self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)