diff --git a/tests/unit/test_ptx_approaches.py b/tests/unit/test_ptx_approaches.py new file mode 100644 index 00000000..7017d78e --- /dev/null +++ b/tests/unit/test_ptx_approaches.py @@ -0,0 +1,87 @@ +"""Test: try nvvm.inline_ptx with multi-line PTX block like the tutorial.""" +import torch +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +from cutlass.cutlass_dsl import dsl_user_op, T +from cutlass._mlir.dialects import nvvm +from cutlass.cute.typing import Float32, Int32 + + +# Approach 1: Simple single-line PTX (what we want) +@dsl_user_op +def f32_to_i32_simple(x: Float32, *, loc=None, ip=None) -> Int32: + result = nvvm.inline_ptx( + write_only_args=[T.i32()], + read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], + ptx_code="cvt.rni.s32.f32 $0, $1;", + loc=loc, + ip=ip, + ) + return Int32(result) + + +# Approach 2: Multi-line PTX block (tutorial pattern) +@dsl_user_op +def f32_to_i32_multiline(x: Float32, *, loc=None, ip=None) -> Int32: + result = nvvm.inline_ptx( + write_only_args=[T.i32()], + read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], + ptx_code="{\n\tcvt.rni.s32.f32 $0, $1;\n\t}", + loc=loc, + ip=ip, + ) + return Int32(result) + + +# Approach 3: Using arith.fptosi directly through MLIR +from cutlass._mlir.dialects import arith + +@dsl_user_op +def f32_to_i32_arith(x: Float32, *, loc=None, ip=None) -> Int32: + return Int32( + arith.fptosi(T.i32(), Float32(x).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + + +@cute.kernel +def test_simple(inp: cute.Tensor, out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx == Int32(0): + x = cute.arch.load(inp.iterator, Float32) + r = f32_to_i32_simple(x) + cute.arch.store(out.iterator, r) + + +@cute.kernel +def test_multiline(inp: cute.Tensor, out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx == Int32(0): + x = cute.arch.load(inp.iterator, Float32) + r = f32_to_i32_multiline(x) + cute.arch.store(out.iterator, r) + + +@cute.kernel +def test_arith(inp: cute.Tensor, out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx == Int32(0): + x = cute.arch.load(inp.iterator, Float32) + r = f32_to_i32_arith(x) + cute.arch.store(out.iterator, r) + + +if __name__ == "__main__": + x = torch.tensor([3.7], dtype=torch.float32, device='cuda') + o = torch.zeros(1, dtype=torch.int32, device='cuda') + xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0) + oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0) + + for name, kernel in [("simple", test_simple), ("multiline", test_multiline), ("arith", test_arith)]: + print(f"\n=== Testing {name} ===") + o.zero_() + try: + compiled = cute.compile(kernel, xc, oc) + compiled(xc, oc) + print(f"{name}: Result = {o.item()} (expected 4)") + except Exception as e: + print(f"{name} FAILED: {e}")