diff --git a/dsv4/kernels/gemm/fp4_quant.py b/dsv4/kernels/gemm/fp4_quant.py index bca330c2..e323fd85 100644 --- a/dsv4/kernels/gemm/fp4_quant.py +++ b/dsv4/kernels/gemm/fp4_quant.py @@ -59,7 +59,7 @@ def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32: @cute.jit - def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: +def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: """Convert a positive Float32 value to FP8 E4M3 bit pattern (returned as Int32). Only handles positive values (NVFP4 scale factors are always positive).