diff --git a/dsv4/kernels/gemm/fp4_quant.py b/dsv4/kernels/gemm/fp4_quant.py index 6d4199b0..e5a1c9d2 100644 --- a/dsv4/kernels/gemm/fp4_quant.py +++ b/dsv4/kernels/gemm/fp4_quant.py @@ -13,26 +13,34 @@ FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8): - Min positive subnormal: 2^(-9) ≈ 0.001953 CuTeDSL constraints: -- float-to-int via inline PTX cvt.rni.s32.f32 (proper hardware rounding) +- float-to-int via NVVM inline PTX cvt.rni.s32.f32 (proper hardware rounding) + Using nvvm.inline_ptx which lowers correctly through the NVVM pipeline. + The llvm.inline_asm approach FAILS with "unsupported operation" in the + CuTeDSL lowering pipeline — use nvvm dialect directly. - `cute.arch.fmax`/`cute.arch.fmin` for float min/max (NOT cute.math.fmin/fmax) - `@cute.jit` decorator required for CuTeDSL functions with dynamic `if` blocks - `cutlass.Int32(N)` creates Int32 constants; `cutlass.Float32(N)` creates Float32 constants -- `@dsl_user_op` + `llvm.inline_asm` for PTX instructions not exposed by CuTeDSL +- `@dsl_user_op` for PTX instruction wrappers """ import cutlass import cutlass.cute as cute from cutlass.cutlass_dsl import dsl_user_op, T -from cutlass._mlir.dialects import llvm +from cutlass._mlir.dialects import nvvm from cutlass.cute.typing import Float32, Int32 FP8_E4M3_BIAS = 7 -# ── Inline PTX float-to-int conversion ────────────────────────────── +# ── NVVM inline PTX float-to-int conversion ───────────────────────── # CuTeDSL has no built-in f32→i32 conversion. We wrap PTX cvt instructions -# via @dsl_user_op + llvm.inline_asm. This is the PROPER approach — hardware +# via @dsl_user_op + nvvm.inline_ptx. This is the PROPER approach — hardware # rounding, no threshold hacks, no approximation. +# +# IMPORTANT: We use nvvm.inline_ptx (NVVM dialect), NOT llvm.inline_asm. +# llvm.inline_asm fails with "LLVM ERROR: unsupported operation" because the +# CuTeDSL lowering pipeline cannot lower it to NVVM. The nvvm.inline_ptx op +# is native to the NVVM dialect and lowers correctly. @dsl_user_op def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32: @@ -41,19 +49,14 @@ def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32: Wraps PTX: cvt.rni.s32.f32 $0, $1; Equivalent to CUDA __float2int_rn(). """ - return Int32( - llvm.inline_asm( - T.i32(), - [Float32(x).ir_value(loc=loc, ip=ip)], - "cvt.rni.s32.f32 $0, $1;", - "=r,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - loc=loc, - ip=ip, - ) + result = nvvm.inline_ptx( + write_only_args=[T.i32()], + read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], + ptx_code="cvt.rni.s32.f32 $0, $1;", + loc=loc, + ip=ip, ) + return Int32(result) @dsl_user_op @@ -61,21 +64,16 @@ def f32_to_i32_rz(x: Float32, *, loc=None, ip=None) -> Int32: """Convert Float32 to Int32 with round-toward-zero (RZ). Wraps PTX: cvt.rzi.s32.f32 $0, $1; - Equivalent to CUDA __float2int_rz(). Used for floor-based extraction. + Equivalent to CUDA __float2int_rz(). Used for truncation. """ - return Int32( - llvm.inline_asm( - T.i32(), - [Float32(x).ir_value(loc=loc, ip=ip)], - "cvt.rzi.s32.f32 $0, $1;", - "=r,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - loc=loc, - ip=ip, - ) + result = nvvm.inline_ptx( + write_only_args=[T.i32()], + read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], + ptx_code="cvt.rzi.s32.f32 $0, $1;", + loc=loc, + ip=ip, ) + return Int32(result) @dsl_user_op @@ -86,19 +84,14 @@ def f32_to_i32_rmi(x: Float32, *, loc=None, ip=None) -> Int32: Equivalent to CUDA __float2int_rd() / floorf() → int. Used for floor(log2()) extraction in FP8 encoding. """ - return Int32( - llvm.inline_asm( - T.i32(), - [Float32(x).ir_value(loc=loc, ip=ip)], - "cvt.rmi.s32.f32 $0, $1;", - "=r,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - loc=loc, - ip=ip, - ) + result = nvvm.inline_ptx( + write_only_args=[T.i32()], + read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], + ptx_code="cvt.rmi.s32.f32 $0, $1;", + loc=loc, + ip=ip, ) + return Int32(result) # ── FP8 E4M3 encoding ─────────────────────────────────────────────── @@ -281,13 +274,8 @@ def quantize_e2m1_nibble( # LUT: half_step → E2M1 index # [0,1,2,3,4,4,5,6,6,6,7,7] + # Expressed as branch logic (CuTeDSL has no random-access arrays in @cute.jit) idx = cutlass.Int32(0) - # We can't index an array in CuTeDSL, so implement the LUT with comparisons. - # This IS the LUT, just expressed as branch logic since CuTeDSL has no - # random-access shared arrays in @cute.jit scalar functions. - # - # The LUT is: half_step_to_e2m1[hs] for hs in 0..11 - # 0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→6, 8→6, 9→6, 10→7, 11→7 if half_step >= cutlass.Int32(1): idx = cutlass.Int32(1) if half_step >= cutlass.Int32(2):