From a05a76bb6bb5b8513d6c7683e7d7f495f69746c0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 03:59:01 +0000 Subject: [PATCH] NVFP4-1.1: add Int32 cast diagnostic test --- tests/unit/test_int32_cast.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/unit/test_int32_cast.py diff --git a/tests/unit/test_int32_cast.py b/tests/unit/test_int32_cast.py new file mode 100644 index 00000000..a2d5a73b --- /dev/null +++ b/tests/unit/test_int32_cast.py @@ -0,0 +1,31 @@ +"""Quick test: Float32 to Int32 conversion in CuTeDSL.""" +import torch +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch + + +@cute.kernel +def test_int32_cast_kernel( + input_f32: cute.Tensor, + output_i32: cute.Tensor, +): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + f = cute.arch.load(input_f32.iterator, cutlass.Float32) + i = cutlass.Int32(f) # float-to-int conversion + cute.arch.store(output_i32.iterator, i, cutlass.Int32) + + +if __name__ == "__main__": + x = torch.tensor([3.7], dtype=torch.float32, device='cuda') + out = torch.zeros(1, dtype=torch.int32, device='cuda') + xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0) + oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0) + import cuda.bindings.driver as cuda + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + print("Compiling...") + compiled = cute.compile(test_int32_cast_kernel, xc, oc) + print("Running...") + compiled(xc, oc) + print(f'Result: {out.item()} (3=trunc, 4=RNE)')