Fix: keep scales as float8_e4m3fn, don't pack to uint32 (min_all_cuda unsupported)

This commit is contained in:
2026-05-13 21:54:39 +00:00
parent 94233c4dd3
commit ebc0ab0cac

View File

@@ -105,18 +105,14 @@ def transform_nvfp4_weights_for_tilelang(
l2_sf_folded = l2_weight_scale.to(torch.float32)
# Clamp and convert back to UE4M3
l1_sf_clamped = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
l2_sf_clamped = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
# Pack UE4M3 into uint32
l1_sf_packed = _pack_ue4m3_to_uint32(l1_sf_clamped)
l2_sf_packed = _pack_ue4m3_to_uint32(l2_sf_clamped)
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
# Interleave L1 weights (gate_up for 2CTA UMMA)
l1_weight_out = _interleave_l1_weights(l1_weight)
l2_weight_out = l2_weight.contiguous()
return (l1_weight_out, l1_sf_packed), (l2_weight_out, l2_sf_packed)
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
# Alias for drop-in replacement