diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 90fc3018..02eefe70 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -105,18 +105,14 @@ def transform_nvfp4_weights_for_tilelang( l2_sf_folded = l2_weight_scale.to(torch.float32) # Clamp and convert back to UE4M3 - l1_sf_clamped = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - l2_sf_clamped = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - - # Pack UE4M3 into uint32 - l1_sf_packed = _pack_ue4m3_to_uint32(l1_sf_clamped) - l2_sf_packed = _pack_ue4m3_to_uint32(l2_sf_clamped) + l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() + l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() # Interleave L1 weights (gate_up for 2CTA UMMA) l1_weight_out = _interleave_l1_weights(l1_weight) l2_weight_out = l2_weight.contiguous() - return (l1_weight_out, l1_sf_packed), (l2_weight_out, l2_sf_packed) + return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out) # Alias for drop-in replacement