NVFP4-1.1: add @cute.jit decorator to fp4_quant functions for CuTeDSL if-block support
This commit is contained in:
@@ -31,6 +31,7 @@ import cutlass.cute as cute
|
||||
FP8_E4M3_BIAS = 7
|
||||
|
||||
|
||||
+@cute.jit
|
||||
def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32:
|
||||
"""Map half-step value (0-12) to E2M1 index (0-7).
|
||||
|
||||
@@ -57,7 +58,8 @@ def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32:
|
||||
return result
|
||||
|
||||
|
||||
-def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
+@cute.jit
|
||||
+def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
"""Convert a positive Float32 value to FP8 E4M3 bit pattern (returned as Int32).
|
||||
|
||||
Only handles positive values (NVFP4 scale factors are always positive).
|
||||
@@ -136,6 +138,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
return result
|
||||
|
||||
|
||||
+@cute.jit
|
||||
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
|
||||
"""Convert FP8 E4M3 bit pattern (in Int32) back to Float32.
|
||||
|
||||
@@ -180,6 +183,7 @@ def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
|
||||
return result
|
||||
|
||||
|
||||
+@cute.jit
|
||||
def quantize_e2m1_nibble(
|
||||
val: cutlass.Float32,
|
||||
scale: cutlass.Float32,
|
||||
|
||||
Reference in New Issue
Block a user