From ebc0ab0cacde5cfd540cdcca464d68351f7a92b1 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 13 May 2026 21:54:39 +0000 Subject: [PATCH] Fix: keep scales as float8_e4m3fn, don't pack to uint32 (min_all_cuda unsupported) --- src/nvfp4_megamoe_kernel/weight_transform.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 90fc3018..02eefe70 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -105,18 +105,14 @@ def transform_nvfp4_weights_for_tilelang( l2_sf_folded = l2_weight_scale.to(torch.float32) # Clamp and convert back to UE4M3 - l1_sf_clamped = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - l2_sf_clamped = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - - # Pack UE4M3 into uint32 - l1_sf_packed = _pack_ue4m3_to_uint32(l1_sf_clamped) - l2_sf_packed = _pack_ue4m3_to_uint32(l2_sf_clamped) + l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() + l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() # Interleave L1 weights (gate_up for 2CTA UMMA) l1_weight_out = _interleave_l1_weights(l1_weight) l2_weight_out = l2_weight.contiguous() - return (l1_weight_out, l1_sf_packed), (l2_weight_out, l2_sf_packed) + return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out) # Alias for drop-in replacement