"""Test: try different constraint strings for llvm.inline_asm cvt.rni.

Experiment script (originally added as tests/unit/test_ptx_constraints.py).
It builds a one-thread kernel that loads a single float32, converts it to an
int32 with the selected approach, and stores the result.

Usage:
    python test_ptx_constraints.py [r_r|bitcast|via_bitcast]

The approach defaults to "r_r" when no argument is given.
"""
import sys

import torch
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
from cutlass.cutlass_dsl import dsl_user_op
from cutlass._mlir.dialects import llvm
from cutlass.cute.typing import Float32, Int32


_DEFAULT_APPROACH = "r_r"

# Only honor the command line when this file is executed as a script.  When
# the module is merely imported (for example by a test collector),
# sys.argv[1] belongs to the importing program and is not an approach name.
approach = (
    sys.argv[1]
    if __name__ == "__main__" and len(sys.argv) > 1
    else _DEFAULT_APPROACH
)


@dsl_user_op
def f32_to_i32_r_r(x: Float32, *, loc=None, ip=None) -> Int32:
    """Convert ``x`` with ``cvt.rni.s32.f32`` using the "=r,r" constraints.

    Both the output and the float input are declared with the ``r``
    constraint (general 32-bit registers); this is the variant under test.

    Args:
        x: Value to convert; wrapped in ``Float32`` before lowering.
        loc: Optional MLIR location forwarded to the emitted ops.
        ip: Optional MLIR insertion point forwarded to the emitted ops.

    Returns:
        The inline-asm result wrapped as ``Int32``.
    """
    val_i32 = llvm.inline_asm(
        Int32._mlir_type(),
        [Float32(x).ir_value(loc=loc, ip=ip)],
        "cvt.rni.s32.f32 $0, $1;",
        "=r,r",
        has_side_effects=False,
        is_align_stack=False,
        loc=loc,
        ip=ip,
    )
    return Int32(val_i32)


@dsl_user_op
def f32_bitcast_to_i32(x: Float32, *, loc=None, ip=None) -> Int32:
    """Reinterpret the bits of ``x`` as an ``Int32`` via ``llvm.bitcast``.

    This is a bit-level reinterpretation, not a numeric float-to-int
    conversion, so the result is not the rounded value of ``x``.

    Args:
        x: Value to reinterpret; wrapped in ``Float32`` before lowering.
        loc: Optional MLIR location forwarded to the emitted ops.
        ip: Optional MLIR insertion point forwarded to the emitted ops.

    Returns:
        The bitcast result wrapped as ``Int32``.
    """
    val_i32 = llvm.bitcast(
        Int32._mlir_type(),
        Float32(x).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )
    return Int32(val_i32)


@dsl_user_op
def f32_to_i32_via_bitcast(x: Float32, *, loc=None, ip=None) -> Int32:
    """Convert ``x`` with ``cvt.rni.s32.f32`` using the "=r,f" constraints.

    NOTE: despite the name, no bitcast is performed here.  The float value
    is passed straight to the inline asm with the ``f`` constraint, and the
    instruction is wrapped in a ``{ ... }`` PTX scope.  The name (and the
    "via_bitcast" key in ``FUNCS``) is kept so existing invocations work.

    Args:
        x: Value to convert; wrapped in ``Float32`` before lowering.
        loc: Optional MLIR location forwarded to the emitted ops.
        ip: Optional MLIR insertion point forwarded to the emitted ops.

    Returns:
        The inline-asm result wrapped as ``Int32``.
    """
    x_f32 = Float32(x).ir_value(loc=loc, ip=ip)
    val_i32 = llvm.inline_asm(
        Int32._mlir_type(),
        [x_f32],
        "{\n\tcvt.rni.s32.f32 $0, $1;\n\t}",
        "=r,f",
        has_side_effects=False,
        is_align_stack=False,
        loc=loc,
        ip=ip,
    )
    return Int32(val_i32)


# Command-line approach name -> conversion op used inside the kernel.
FUNCS = {
    "r_r": f32_to_i32_r_r,
    "bitcast": f32_bitcast_to_i32,
    "via_bitcast": f32_to_i32_via_bitcast,
}

if approach not in FUNCS:
    # Fail with the list of valid choices instead of a bare KeyError.
    raise SystemExit(
        f"Unknown approach {approach!r}; expected one of: {', '.join(FUNCS)}"
    )

func = FUNCS[approach]


# Kernel: thread 0 loads one float from `inp`, applies the selected
# conversion, and stores the resulting int to `out`.
@cute.kernel
def test_k(inp: cute.Tensor, out: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == Int32(0):
        x = cute.arch.load(inp.iterator, Float32)
        r = func(x)
        cute.arch.store(out.iterator, r)


if __name__ == "__main__":
    # Single-element input/output buffers on the CUDA device.
    x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
    o = torch.zeros(1, dtype=torch.int32, device='cuda')
    xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
    oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
    print(f"Approach: {approach}")
    print("Compiling...")
    compiled = cute.compile(test_k, xc, oc)
    print("Running...")
    compiled(xc, oc)
    print(f"Result: {o.item()}")