test: compare nvvm.inline_ptx approaches + arith.fptosi

This commit is contained in:
2026-05-28 04:46:06 +00:00
parent eebf33b97d
commit 136a89f4e3

View File

@@ -0,0 +1,87 @@
"""Test: try nvvm.inline_ptx with multi-line PTX block like the tutorial."""
import torch
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import nvvm
from cutlass.cute.typing import Float32, Int32
# Approach 1: Simple single-line PTX (what we want)
@dsl_user_op
def f32_to_i32_simple(x: Float32, *, loc=None, ip=None) -> Int32:
result = nvvm.inline_ptx(
write_only_args=[T.i32()],
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
ptx_code="cvt.rni.s32.f32 $0, $1;",
loc=loc,
ip=ip,
)
return Int32(result)
# Approach 2: Multi-line PTX block (tutorial pattern)
@dsl_user_op
def f32_to_i32_multiline(x: Float32, *, loc=None, ip=None) -> Int32:
result = nvvm.inline_ptx(
write_only_args=[T.i32()],
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
ptx_code="{\n\tcvt.rni.s32.f32 $0, $1;\n\t}",
loc=loc,
ip=ip,
)
return Int32(result)
# Approach 3: Using arith.fptosi directly through MLIR
from cutlass._mlir.dialects import arith
@dsl_user_op
def f32_to_i32_arith(x: Float32, *, loc=None, ip=None) -> Int32:
return Int32(
arith.fptosi(T.i32(), Float32(x).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
)
@cute.kernel
def test_simple(inp: cute.Tensor, out: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
if tidx == Int32(0):
x = cute.arch.load(inp.iterator, Float32)
r = f32_to_i32_simple(x)
cute.arch.store(out.iterator, r)
@cute.kernel
def test_multiline(inp: cute.Tensor, out: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
if tidx == Int32(0):
x = cute.arch.load(inp.iterator, Float32)
r = f32_to_i32_multiline(x)
cute.arch.store(out.iterator, r)
@cute.kernel
def test_arith(inp: cute.Tensor, out: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
if tidx == Int32(0):
x = cute.arch.load(inp.iterator, Float32)
r = f32_to_i32_arith(x)
cute.arch.store(out.iterator, r)
if __name__ == "__main__":
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
o = torch.zeros(1, dtype=torch.int32, device='cuda')
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
for name, kernel in [("simple", test_simple), ("multiline", test_multiline), ("arith", test_arith)]:
print(f"\n=== Testing {name} ===")
o.zero_()
try:
compiled = cute.compile(kernel, xc, oc)
compiled(xc, oc)
print(f"{name}: Result = {o.item()} (expected 4)")
except Exception as e:
print(f"{name} FAILED: {e}")