From f6f59d34cb1d8b39e9be83125807fc4cd9253063 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Thu, 28 May 2026 03:50:11 +0000
Subject: [PATCH] NVFP4-1.1: add @cute.jit decorator to fp4_quant functions for CuTeDSL if-block support

---
Review notes (this section is ignored by `git am`):
- The original second hunk removed `def fp8_e4m3_from_float32(...)` and
  re-added it with leading whitespace, which would place an indented `def`
  under a column-0 decorator and raise IndentationError on import. The `def`
  line is now left untouched as context; only the decorator is added.
- Blank context lines and docstring indentation were reconstructed from the
  hunk line counts; verify with `git apply --check` before applying.
- The post-image hash on the `index` line predates this correction and will
  not match the resulting blob (only relevant for `git am -3`).

 dsv4/kernels/gemm/fp4_quant.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/dsv4/kernels/gemm/fp4_quant.py b/dsv4/kernels/gemm/fp4_quant.py
index 59dd2d14..bca330c2 100644
--- a/dsv4/kernels/gemm/fp4_quant.py
+++ b/dsv4/kernels/gemm/fp4_quant.py
@@ -31,6 +31,7 @@ import cutlass.cute as cute
 FP8_E4M3_BIAS = 7
 
 
+@cute.jit
 def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32:
     """Map half-step value (0-12) to E2M1 index (0-7).
 
@@ -57,7 +58,8 @@ def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32:
     return result
 
 
+@cute.jit
 def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
     """Convert a positive Float32 value to FP8 E4M3 bit pattern (returned as Int32).
 
     Only handles positive values (NVFP4 scale factors are always positive).
@@ -136,6 +138,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
     return result
 
 
+@cute.jit
 def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
     """Convert FP8 E4M3 bit pattern (in Int32) back to Float32.
 
@@ -180,6 +183,7 @@ def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
     return result
 
 
+@cute.jit
 def quantize_e2m1_nibble(
     val: cutlass.Float32,
     scale: cutlass.Float32,