"""Test: try nvvm.inline_ptx with extra debug info.

Compiles and launches a one-thread-of-work CuTe DSL kernel that converts a
single float32 value to int32 through an inline PTX ``cvt.rni.s32.f32``
instruction, then prints the result next to the expected value.
"""
import os

# Set before any cutlass import, as in the original script.
# NOTE(review): presumably cutlass reads this variable at import time -- confirm.
os.environ['CUTLASS_LOG_LEVEL'] = 'DEBUG'

import torch

# Bind the top-level ``cutlass`` name explicitly: ``import cutlass.cute as cute``
# only binds ``cute``, yet the kernel below refers to ``cutlass.Int32`` and
# ``cutlass.Float32``.
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import nvvm
from cutlass.cute.typing import Float32, Int32


@dsl_user_op
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
    """Convert ``x`` to a signed 32-bit integer via inline PTX.

    Emits ``cvt.rni.s32.f32 $0, $1;`` through ``nvvm.inline_ptx`` with one
    i32 write-only operand (``$0``) and ``x`` as the read-only operand
    (``$1``).

    Args:
        x: Value to convert; it is wrapped in ``Float32`` before lowering.
        loc: Optional location forwarded to the IR builders.
        ip: Optional insertion point forwarded to the IR builders.

    Returns:
        The result of the inline PTX op wrapped in ``Int32``.
    """
    result = nvvm.inline_ptx(
        write_only_args=[T.i32()],
        read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
        ptx_code="cvt.rni.s32.f32 $0, $1;",
        loc=loc,
        ip=ip,
    )
    return Int32(result)


@cute.kernel
def minimal_test_kernel(
    input_f32: cute.Tensor,
    output_i32: cute.Tensor,
):
    """Read one float32, convert it with ``f32_to_i32_rni``, store the int32.

    Only the thread whose x index is 0 performs the load, conversion, and
    store; every other thread does nothing.

    Args:
        input_f32: Tensor whose iterator is loaded from as ``Float32``.
        output_i32: Tensor whose iterator receives the converted value.
    """
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == cutlass.Int32(0):
        x = cute.arch.load(input_f32.iterator, cutlass.Float32)
        result = f32_to_i32_rni(x)
        cute.arch.store(output_i32.iterator, result)


if __name__ == "__main__":
    x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
    out = torch.zeros(1, dtype=torch.int32, device='cuda')

    xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
    oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)

    print("Compiling...")
    compiled = cute.compile(minimal_test_kernel, xc, oc)

    print("Running...")
    compiled(xc, oc)

    print(f'f32_to_i32_rni(3.7) = {out.item()} (expected 4)')