diff --git a/tests/unit/test_ptx_debug.py b/tests/unit/test_ptx_debug.py new file mode 100644 index 00000000..5c9ca190 --- /dev/null +++ b/tests/unit/test_ptx_debug.py @@ -0,0 +1,47 @@ +"""Test: try nvvm.inline_ptx with extra debug info.""" +import os +os.environ['CUTLASS_LOG_LEVEL'] = 'DEBUG' + +import torch +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +from cutlass.cutlass_dsl import dsl_user_op, T +from cutlass._mlir.dialects import nvvm +from cutlass.cute.typing import Float32, Int32 + + +@dsl_user_op +def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32: + result = nvvm.inline_ptx( + write_only_args=[T.i32()], + read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], + ptx_code="cvt.rni.s32.f32 $0, $1;", + loc=loc, + ip=ip, + ) + return Int32(result) + + +@cute.kernel +def minimal_test_kernel( + input_f32: cute.Tensor, + output_i32: cute.Tensor, +): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + x = cute.arch.load(input_f32.iterator, cutlass.Float32) + result = f32_to_i32_rni(x) + cute.arch.store(output_i32.iterator, result) + + +if __name__ == "__main__": + x = torch.tensor([3.7], dtype=torch.float32, device='cuda') + out = torch.zeros(1, dtype=torch.int32, device='cuda') + xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0) + oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0) + + print("Compiling...") + compiled = cute.compile(minimal_test_kernel, xc, oc) + print("Running...") + compiled(xc, oc) + print(f'f32_to_i32_rni(3.7) = {out.item()} (expected 4)') diff --git a/tests/unit/test_ptx_v2.py b/tests/unit/test_ptx_v2.py new file mode 100644 index 00000000..00ba4110 --- /dev/null +++ b/tests/unit/test_ptx_v2.py @@ -0,0 +1,87 @@ +"""Test: try different approaches to nvvm.inline_ptx wrapping.""" +import torch +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +from cutlass.cutlass_dsl import dsl_user_op, T +from cutlass._mlir.dialects import nvvm +from cutlass.cute.typing import Float32, Int32 + + +# Approach 1: Return raw MLIR value, wrap at call site +@dsl_user_op +def f32_to_i32_raw(x: Float32, *, loc=None, ip=None) -> Int32: + result = nvvm.inline_ptx( + write_only_args=[T.i32()], + read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], + ptx_code="cvt.rni.s32.f32 $0, $1;", + loc=loc, + ip=ip, + ) + # nvvm.inline_ptx returns a Value; Int32() should wrap it + return Int32(result) + + +# Approach 2: Use nvvm.inline_ptx with two outputs (matching tutorial pattern) +# Try with has_side_effects-like pattern +@dsl_user_op +def f32_to_i32_v2(x: Float32, *, loc=None, ip=None) -> Int32: + # Use the exact same pattern as the tutorial's ptx_vote_ballot_sync + return Int32( + nvvm.inline_ptx( + [T.i32()], + [Float32(x).ir_value(loc=loc, ip=ip)], + "cvt.rni.s32.f32 $0, $1;", + loc=loc, + ip=ip, + ) + ) + + +@cute.kernel +def test_kernel_v1( + input_f32: cute.Tensor, + output_i32: cute.Tensor, +): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + x = cute.arch.load(input_f32.iterator, cutlass.Float32) + result = f32_to_i32_raw(x) + cute.arch.store(output_i32.iterator, result) + + +@cute.kernel +def test_kernel_v2( + input_f32: cute.Tensor, + output_i32: cute.Tensor, +): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + x = cute.arch.load(input_f32.iterator, cutlass.Float32) + result = f32_to_i32_v2(x) + cute.arch.store(output_i32.iterator, result) + + +if __name__ == "__main__": + x = torch.tensor([3.7], dtype=torch.float32, device='cuda') + out = torch.zeros(1, dtype=torch.int32, device='cuda') + xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0) + oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0) + + print("=== Test V1 (raw result) ===") + try: + compiled = cute.compile(test_kernel_v1, xc, oc) + compiled(xc, oc) + print(f'V1: f32_to_i32(3.7) = {out.item()}') + except Exception as e: + print(f'V1 FAILED: {e}') + + out.zero_() + + print("\n=== Test V2 (list-style args) ===") + try: + compiled = cute.compile(test_kernel_v2, xc, oc) + compiled(xc, oc) + print(f'V2: f32_to_i32(3.7) = {out.item()}') + except Exception as e: + print(f'V2 FAILED: {e}')