NVFP4-1.1: Use nvvm.inline_ptx instead of llvm.inline_asm for f32→i32
llvm.inline_asm fails with 'LLVM ERROR: unsupported operation' in CuTeDSL lowering pipeline. Switch to nvvm.inline_ptx which is native to the NVVM dialect and lowers correctly. - f32_to_i32_rni: cvt.rni.s32.f32 via nvvm.inline_ptx - f32_to_i32_rz: cvt.rzi.s32.f32 via nvvm.inline_ptx - f32_to_i32_rmi: cvt.rmi.s32.f32 via nvvm.inline_ptx
This commit is contained in:
@@ -13,26 +13,34 @@ FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
|
||||
- Min positive subnormal: 2^(-9) ≈ 0.001953
|
||||
|
||||
CuTeDSL constraints:
|
||||
- float-to-int via inline PTX cvt.rni.s32.f32 (proper hardware rounding)
|
||||
- float-to-int via NVVM inline PTX cvt.rni.s32.f32 (proper hardware rounding)
|
||||
Using nvvm.inline_ptx which lowers correctly through the NVVM pipeline.
|
||||
The llvm.inline_asm approach FAILS with "unsupported operation" in the
|
||||
CuTeDSL lowering pipeline — use nvvm dialect directly.
|
||||
- `cute.arch.fmax`/`cute.arch.fmin` for float min/max (NOT cute.math.fmin/fmax)
|
||||
- `@cute.jit` decorator required for CuTeDSL functions with dynamic `if` blocks
|
||||
- `cutlass.Int32(N)` creates Int32 constants; `cutlass.Float32(N)` creates Float32 constants
|
||||
- `@dsl_user_op` + `llvm.inline_asm` for PTX instructions not exposed by CuTeDSL
|
||||
- `@dsl_user_op` for PTX instruction wrappers
|
||||
"""
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import llvm
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
FP8_E4M3_BIAS = 7
|
||||
|
||||
|
||||
# ── Inline PTX float-to-int conversion ──────────────────────────────
|
||||
# ── NVVM inline PTX float-to-int conversion ─────────────────────────
|
||||
# CuTeDSL has no built-in f32→i32 conversion. We wrap PTX cvt instructions
|
||||
# via @dsl_user_op + llvm.inline_asm. This is the PROPER approach — hardware
|
||||
# via @dsl_user_op + nvvm.inline_ptx. This is the PROPER approach — hardware
|
||||
# rounding, no threshold hacks, no approximation.
|
||||
#
|
||||
# IMPORTANT: We use nvvm.inline_ptx (NVVM dialect), NOT llvm.inline_asm.
|
||||
# llvm.inline_asm fails with "LLVM ERROR: unsupported operation" because the
|
||||
# CuTeDSL lowering pipeline cannot lower it to NVVM. The nvvm.inline_ptx op
|
||||
# is native to the NVVM dialect and lowers correctly.
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
@@ -41,19 +49,14 @@ def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
Wraps PTX: cvt.rni.s32.f32 $0, $1;
|
||||
Equivalent to CUDA __float2int_rn().
|
||||
"""
|
||||
return Int32(
|
||||
llvm.inline_asm(
|
||||
T.i32(),
|
||||
[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
"cvt.rni.s32.f32 $0, $1;",
|
||||
"=r,f",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@@ -61,21 +64,16 @@ def f32_to_i32_rz(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
"""Convert Float32 to Int32 with round-toward-zero (RZ).
|
||||
|
||||
Wraps PTX: cvt.rzi.s32.f32 $0, $1;
|
||||
Equivalent to CUDA __float2int_rz(). Used for floor-based extraction.
|
||||
Equivalent to CUDA __float2int_rz(). Used for truncation.
|
||||
"""
|
||||
return Int32(
|
||||
llvm.inline_asm(
|
||||
T.i32(),
|
||||
[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
"cvt.rzi.s32.f32 $0, $1;",
|
||||
"=r,f",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rzi.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@@ -86,19 +84,14 @@ def f32_to_i32_rmi(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
Equivalent to CUDA __float2int_rd() / floorf() → int.
|
||||
Used for floor(log2()) extraction in FP8 encoding.
|
||||
"""
|
||||
return Int32(
|
||||
llvm.inline_asm(
|
||||
T.i32(),
|
||||
[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
"cvt.rmi.s32.f32 $0, $1;",
|
||||
"=r,f",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rmi.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
# ── FP8 E4M3 encoding ───────────────────────────────────────────────
|
||||
@@ -281,13 +274,8 @@ def quantize_e2m1_nibble(
|
||||
|
||||
# LUT: half_step → E2M1 index
|
||||
# [0,1,2,3,4,4,5,6,6,6,7,7]
|
||||
# Expressed as branch logic (CuTeDSL has no random-access arrays in @cute.jit)
|
||||
idx = cutlass.Int32(0)
|
||||
# We can't index an array in CuTeDSL, so implement the LUT with comparisons.
|
||||
# This IS the LUT, just expressed as branch logic since CuTeDSL has no
|
||||
# random-access shared arrays in @cute.jit scalar functions.
|
||||
#
|
||||
# The LUT is: half_step_to_e2m1[hs] for hs in 0..11
|
||||
# 0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→6, 8→6, 9→6, 10→7, 11→7
|
||||
if half_step >= cutlass.Int32(1):
|
||||
idx = cutlass.Int32(1)
|
||||
if half_step >= cutlass.Int32(2):
|
||||
|
||||
Reference in New Issue
Block a user