Fix: keep scales as float8_e4m3fn, don't pack to uint32 (min_all_cuda unsupported)
This commit is contained in:
@@ -105,18 +105,14 @@ def transform_nvfp4_weights_for_tilelang(
|
||||
l2_sf_folded = l2_weight_scale.to(torch.float32)
|
||||
|
||||
# Clamp and convert back to UE4M3
|
||||
l1_sf_clamped = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
l2_sf_clamped = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
|
||||
# Pack UE4M3 into uint32
|
||||
l1_sf_packed = _pack_ue4m3_to_uint32(l1_sf_clamped)
|
||||
l2_sf_packed = _pack_ue4m3_to_uint32(l2_sf_clamped)
|
||||
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
|
||||
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
|
||||
|
||||
# Interleave L1 weights (gate_up for 2CTA UMMA)
|
||||
l1_weight_out = _interleave_l1_weights(l1_weight)
|
||||
l2_weight_out = l2_weight.contiguous()
|
||||
|
||||
return (l1_weight_out, l1_sf_packed), (l2_weight_out, l2_sf_packed)
|
||||
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
|
||||
|
||||
|
||||
# Alias for drop-in replacement
|
||||
|
||||
Reference in New Issue
Block a user