fix: convert uint8 checkpoint weights to float4_e2m1fn_x2 for CuTeDSL GEMM
The CuTeDSL kernel expects the float4_e2m1fn_x2 dtype for FP4 weight tensors, but checkpoint weights loaded from safetensors arrive as uint8. uint8 and float4_e2m1fn_x2 have the same byte representation, so reinterpreting the tensor with .view() is safe.

Fixed in:
- Nvfp4Linear.finalize_weights()
- Nvfp4SharedExpert.finalize_weights()
- Nvfp4MoE._ensure_stacked() (both stacked and legacy paths)
This commit is contained in:
@@ -70,7 +70,9 @@ class Nvfp4Linear:
|
||||
|
||||
def finalize_weights(self):
|
||||
"""Process weights for CuTeDSL GEMM."""
|
||||
self._mat_b = make_b_k_major(torch.stack(self.fp4)) # (1, K_packed, N_packed)
|
||||
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
|
||||
fp4_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.fp4]
|
||||
self._mat_b = make_b_k_major(torch.stack(fp4_view)) # (1, K_packed, N_packed)
|
||||
self._scale_b = assemble_scales_3d_side(self.sf)
|
||||
self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device)
|
||||
|
||||
|
||||
@@ -210,6 +210,11 @@ class Nvfp4MoE:
|
||||
# This pairs gate/up within the MMA accumulator, enabling
|
||||
# fused SwiGLU without runtime conditionals.
|
||||
l1_fp4_ekn = interleave_l1_weights(l1_fp4_ekn)
|
||||
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
|
||||
if l1_fp4_ekn.dtype == torch.uint8:
|
||||
l1_fp4_ekn = l1_fp4_ekn.view(torch.float4_e2m1fn_x2)
|
||||
if l2_fp4_ekn.dtype == torch.uint8:
|
||||
l2_fp4_ekn = l2_fp4_ekn.view(torch.float4_e2m1fn_x2)
|
||||
# Free stacked checkpoints before make_b_k_major (saves one copy)
|
||||
self.l1_fp4_stacked = None
|
||||
self.l2_fp4_stacked = None
|
||||
@@ -253,8 +258,13 @@ class Nvfp4MoE:
|
||||
# Legacy path: per-expert lists
|
||||
l1_stacked = torch.stack(self.l1_fp4) # (E, K, N)
|
||||
l1_stacked = interleave_l1_weights(l1_stacked) # interleave gate/up
|
||||
if l1_stacked.dtype == torch.uint8:
|
||||
l1_stacked = l1_stacked.view(torch.float4_e2m1fn_x2)
|
||||
l2_stacked = torch.stack(self.l2_fp4)
|
||||
if l2_stacked.dtype == torch.uint8:
|
||||
l2_stacked = l2_stacked.view(torch.float4_e2m1fn_x2)
|
||||
self._l1_mat_b = make_b_k_major(l1_stacked)
|
||||
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
|
||||
self._l2_mat_b = make_b_k_major(l2_stacked)
|
||||
# Interleave L1 SF to match weight interleave
|
||||
# SF from quantize_weight_to_nvfp4 is (K_sf, N). Interleave along N,
|
||||
# then transpose to (N, K_sf) for swizzle via assemble_scales_3d_side.
|
||||
|
||||
@@ -102,10 +102,12 @@ class Nvfp4SharedExpert:
|
||||
|
||||
def finalize_weights(self):
|
||||
"""Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights."""
|
||||
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
|
||||
l1_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l1_fp4]
|
||||
l2_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l2_fp4]
|
||||
# Stack weights and convert to K-major
|
||||
# l1_fp4/l2_fp4 are lists with 1 element (the shared expert)
|
||||
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) # (1, K_packed, N_packed)
|
||||
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
|
||||
self._l1_mat_b = make_b_k_major(torch.stack(l1_view)) # (1, K_packed, N_packed)
|
||||
self._l2_mat_b = make_b_k_major(torch.stack(l2_view))
|
||||
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) # (1, N, K_sf_padded)
|
||||
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
|
||||
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
|
||||
|
||||
Reference in New Issue
Block a user