NVFP4-1.1: add @cute.jit decorator to fp4_quant functions for CuTeDSL if-block support

This commit is contained in:
2026-05-28 03:50:11 +00:00
parent 0ecb98daee
commit f6f59d34cb

View File

@@ -31,6 +31,7 @@ import cutlass.cute as cute
FP8_E4M3_BIAS = 7
@cute.jit
def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32:
"""Map half-step value (0-12) to E2M1 index (0-7).
@@ -57,7 +58,8 @@ def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32:
return result
def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
@cute.jit
def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
"""Convert a positive Float32 value to FP8 E4M3 bit pattern (returned as Int32).
Only handles positive values (NVFP4 scale factors are always positive).
@@ -136,6 +138,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
return result
@cute.jit
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
"""Convert FP8 E4M3 bit pattern (in Int32) back to Float32.
@@ -180,6 +183,7 @@ def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
return result
@cute.jit
def quantize_e2m1_nibble(
val: cutlass.Float32,
scale: cutlass.Float32,