"""Quick test: Float32 to Int32 conversion in CuTeDSL.""" import torch import cutlass import cutlass.cute as cute import cutlass.torch as cutlass_torch @cute.kernel def test_int32_cast_kernel( input_f32: cute.Tensor, output_i32: cute.Tensor, ): tidx, _, _ = cute.arch.thread_idx() if tidx == cutlass.Int32(0): f = cute.arch.load(input_f32.iterator, cutlass.Float32) i = f.to(cutlass.Int32) # float-to-int conversion cute.arch.store(output_i32.iterator, i) if __name__ == "__main__": x = torch.tensor([3.7], dtype=torch.float32, device='cuda') out = torch.zeros(1, dtype=torch.int32, device='cuda') xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0) oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0) import cuda.bindings.driver as cuda stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) print("Compiling...") compiled = cute.compile(test_int32_cast_kernel, xc, oc) print("Running...") compiled(xc, oc) print(f'Result: {out.item()} (3=trunc, 4=RNE)')