fix: convert uint8 checkpoint weights to float4_e2m1fn_x2 for CuTeDSL GEMM

The CuTeDSL kernel expects float4_e2m1fn_x2 dtype for FP4 weight tensors,
but checkpoint weights from safetensors are loaded as uint8. The uint8 and
float4_e2m1fn_x2 have the same byte representation, so .view() is safe.

Fixed in:
- Nvfp4Linear.finalize_weights()
- Nvfp4SharedExpert.finalize_weights()
- Nvfp4MoE._ensure_stacked() (both stacked and legacy paths)
This commit is contained in:
2026-06-01 00:18:34 +00:00
parent 172448514c
commit e8a7a9256f
3 changed files with 19 additions and 5 deletions

View File

@@ -70,7 +70,9 @@ class Nvfp4Linear:
def finalize_weights(self):
"""Process weights for CuTeDSL GEMM."""
self._mat_b = make_b_k_major(torch.stack(self.fp4)) # (1, K_packed, N_packed)
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
fp4_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.fp4]
self._mat_b = make_b_k_major(torch.stack(fp4_view)) # (1, K_packed, N_packed)
self._scale_b = assemble_scales_3d_side(self.sf)
self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device)

View File

@@ -210,6 +210,11 @@ class Nvfp4MoE:
# This pairs gate/up within the MMA accumulator, enabling
# fused SwiGLU without runtime conditionals.
l1_fp4_ekn = interleave_l1_weights(l1_fp4_ekn)
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
if l1_fp4_ekn.dtype == torch.uint8:
l1_fp4_ekn = l1_fp4_ekn.view(torch.float4_e2m1fn_x2)
if l2_fp4_ekn.dtype == torch.uint8:
l2_fp4_ekn = l2_fp4_ekn.view(torch.float4_e2m1fn_x2)
# Free stacked checkpoints before make_b_k_major (saves one copy)
self.l1_fp4_stacked = None
self.l2_fp4_stacked = None
@@ -253,8 +258,13 @@ class Nvfp4MoE:
# Legacy path: per-expert lists
l1_stacked = torch.stack(self.l1_fp4) # (E, K, N)
l1_stacked = interleave_l1_weights(l1_stacked) # interleave gate/up
if l1_stacked.dtype == torch.uint8:
l1_stacked = l1_stacked.view(torch.float4_e2m1fn_x2)
l2_stacked = torch.stack(self.l2_fp4)
if l2_stacked.dtype == torch.uint8:
l2_stacked = l2_stacked.view(torch.float4_e2m1fn_x2)
self._l1_mat_b = make_b_k_major(l1_stacked)
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
self._l2_mat_b = make_b_k_major(l2_stacked)
# Interleave L1 SF to match weight interleave
# SF from quantize_weight_to_nvfp4 is (K_sf, N). Interleave along N,
# then transpose to (N, K_sf) for swizzle via assemble_scales_3d_side.

View File

@@ -102,10 +102,12 @@ class Nvfp4SharedExpert:
def finalize_weights(self):
"""Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights."""
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
l1_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l1_fp4]
l2_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l2_fp4]
# Stack weights and convert to K-major
# l1_fp4/l2_fp4 are lists with 1 element (the shared expert)
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) # (1, K_packed, N_packed)
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
self._l1_mat_b = make_b_k_major(torch.stack(l1_view)) # (1, K_packed, N_packed)
self._l2_mat_b = make_b_k_major(torch.stack(l2_view))
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) # (1, N, K_sf_padded)
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)