NVFP4-1.1: Use nvvm.inline_ptx instead of llvm.inline_asm for f32→i32

llvm.inline_asm fails with 'LLVM ERROR: unsupported operation' in the CuTeDSL
lowering pipeline. Switch to nvvm.inline_ptx, which is native to the NVVM
dialect and lowers correctly.

- f32_to_i32_rni: cvt.rni.s32.f32 via nvvm.inline_ptx
- f32_to_i32_rz: cvt.rzi.s32.f32 via nvvm.inline_ptx
- f32_to_i32_rmi: cvt.rmi.s32.f32 via nvvm.inline_ptx
This commit is contained in:
2026-05-28 04:42:33 +00:00
parent 74dba6ab9d
commit e33c48e44c

View File

@@ -13,26 +13,34 @@ FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
- Min positive subnormal: 2^(-9) ≈ 0.001953
CuTeDSL constraints:
- float-to-int via inline PTX cvt.rni.s32.f32 (proper hardware rounding)
- float-to-int via NVVM inline PTX cvt.rni.s32.f32 (proper hardware rounding)
This uses nvvm.inline_ptx, which lowers correctly through the NVVM pipeline.
The llvm.inline_asm approach FAILS with "unsupported operation" in the
CuTeDSL lowering pipeline — use the nvvm dialect directly.
- `cute.arch.fmax`/`cute.arch.fmin` for float min/max (NOT cute.math.fmin/fmax)
- `@cute.jit` decorator required for CuTeDSL functions with dynamic `if` blocks
- `cutlass.Int32(N)` creates Int32 constants; `cutlass.Float32(N)` creates Float32 constants
- `@dsl_user_op` + `llvm.inline_asm` for PTX instructions not exposed by CuTeDSL
- `@dsl_user_op` for PTX instruction wrappers
"""
import cutlass
import cutlass.cute as cute
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import llvm
from cutlass._mlir.dialects import nvvm
from cutlass.cute.typing import Float32, Int32
FP8_E4M3_BIAS = 7
# ── Inline PTX float-to-int conversion ──────────────────────────────
# ── NVVM inline PTX float-to-int conversion ─────────────────────────
# CuTeDSL has no built-in f32→i32 conversion. We wrap PTX cvt instructions
# via @dsl_user_op + llvm.inline_asm. This is the PROPER approach — hardware
# via @dsl_user_op + nvvm.inline_ptx. This is the PROPER approach — hardware
# rounding, no threshold hacks, no approximation.
#
# IMPORTANT: We use nvvm.inline_ptx (NVVM dialect), NOT llvm.inline_asm.
# llvm.inline_asm fails with "LLVM ERROR: unsupported operation" because the
# CuTeDSL lowering pipeline cannot lower it to NVVM. The nvvm.inline_ptx op
# is native to the NVVM dialect and lowers correctly.
@dsl_user_op
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
@@ -41,19 +49,14 @@ def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
Wraps PTX: cvt.rni.s32.f32 $0, $1;
Equivalent to CUDA __float2int_rn().
"""
return Int32(
llvm.inline_asm(
T.i32(),
[Float32(x).ir_value(loc=loc, ip=ip)],
"cvt.rni.s32.f32 $0, $1;",
"=r,f",
has_side_effects=False,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
result = nvvm.inline_ptx(
write_only_args=[T.i32()],
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
ptx_code="cvt.rni.s32.f32 $0, $1;",
loc=loc,
ip=ip,
)
return Int32(result)
@dsl_user_op
@@ -61,21 +64,16 @@ def f32_to_i32_rz(x: Float32, *, loc=None, ip=None) -> Int32:
"""Convert Float32 to Int32 with round-toward-zero (RZ).
Wraps PTX: cvt.rzi.s32.f32 $0, $1;
Equivalent to CUDA __float2int_rz(). Used for floor-based extraction.
Equivalent to CUDA __float2int_rz(). Used for truncation.
"""
return Int32(
llvm.inline_asm(
T.i32(),
[Float32(x).ir_value(loc=loc, ip=ip)],
"cvt.rzi.s32.f32 $0, $1;",
"=r,f",
has_side_effects=False,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
result = nvvm.inline_ptx(
write_only_args=[T.i32()],
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
ptx_code="cvt.rzi.s32.f32 $0, $1;",
loc=loc,
ip=ip,
)
return Int32(result)
@dsl_user_op
@@ -86,19 +84,14 @@ def f32_to_i32_rmi(x: Float32, *, loc=None, ip=None) -> Int32:
Equivalent to CUDA __float2int_rd() / floorf() → int.
Used for floor(log2()) extraction in FP8 encoding.
"""
return Int32(
llvm.inline_asm(
T.i32(),
[Float32(x).ir_value(loc=loc, ip=ip)],
"cvt.rmi.s32.f32 $0, $1;",
"=r,f",
has_side_effects=False,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
result = nvvm.inline_ptx(
write_only_args=[T.i32()],
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
ptx_code="cvt.rmi.s32.f32 $0, $1;",
loc=loc,
ip=ip,
)
return Int32(result)
# ── FP8 E4M3 encoding ───────────────────────────────────────────────
@@ -281,13 +274,8 @@ def quantize_e2m1_nibble(
# LUT: half_step → E2M1 index
# [0,1,2,3,4,4,5,6,6,6,7,7]
# Expressed as branch logic (CuTeDSL has no random-access arrays in @cute.jit)
idx = cutlass.Int32(0)
# We can't index an array in CuTeDSL, so implement the LUT with comparisons.
# This IS the LUT, just expressed as branch logic since CuTeDSL has no
# random-access shared arrays in @cute.jit scalar functions.
#
# The LUT is: half_step_to_e2m1[hs] for hs in 0..11
# 0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→6, 8→6, 9→6, 10→7, 11→7
if half_step >= cutlass.Int32(1):
idx = cutlass.Int32(1)
if half_step >= cutlass.Int32(2):