From 650bcdcccfc2358b3aed0f96fc2226788d18310b Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Thu, 28 May 2026 04:57:45 +0000
Subject: [PATCH] test: f32 vs i32 GMEM store

---
Review note: the file contents below were revised during review; the commit
id above and the blob id on the "index" line still refer to the originally
submitted revision.

 tests/unit/test_dtype_store.py | 58 ++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 tests/unit/test_dtype_store.py

diff --git a/tests/unit/test_dtype_store.py b/tests/unit/test_dtype_store.py
new file mode 100644
index 00000000..34b46914
--- /dev/null
+++ b/tests/unit/test_dtype_store.py
@@ -0,0 +1,58 @@
+"""Test: does Float32 GMEM store work when Int32 doesn't?
+
+Usage: python test_dtype_store.py [f32_store|i32_store]
+"""
+import torch
+import cutlass.cute as cute
+import cutlass.torch as cutlass_torch
+from cutlass.cute.typing import Float32, Int32
+import sys
+
+
+# Copies one Float32 from inp to out, from the thread whose first
+# thread_idx component is 0.
+@cute.kernel
+def f32_store(inp: cute.Tensor, out: cute.Tensor):
+    tidx, _, _ = cute.arch.thread_idx()
+    # Same integer predicate as i32_store, so the two kernels differ only in
+    # what they store (the thread index is not a float).
+    if tidx == Int32(0):
+        x = cute.arch.load(inp.iterator, Float32)
+        cute.arch.store(out.iterator, x)
+
+
+# Stores the constant Int32(42) to out from the same thread; inp is unused.
+@cute.kernel
+def i32_store(inp: cute.Tensor, out: cute.Tensor):
+    tidx, _, _ = cute.arch.thread_idx()
+    if tidx == Int32(0):
+        cute.arch.store(out.iterator, Int32(42))
+
+
+KERNELS = {"f32_store": f32_store, "i32_store": i32_store}
+# dtype of the output buffer each kernel writes to.
+OUT_DTYPES = {"f32_store": torch.float32, "i32_store": torch.int32}
+
+
+def main(argv):
+    """Compile and run the kernel named by argv[1] (default "f32_store")."""
+    # Parsed here, not at import time, so that importing this module (e.g.
+    # pytest collecting test_*.py) never fails on an unrelated sys.argv[1].
+    test = argv[1] if len(argv) > 1 else "f32_store"
+    if test not in KERNELS:
+        sys.exit(f"unknown test {test!r}; expected one of {sorted(KERNELS)}")
+    x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
+    o = torch.zeros(1, dtype=OUT_DTYPES[test], device='cuda')
+    xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
+    oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
+    print(f"Test: {test}")
+    # NOTE(review): @cute.kernel functions are usually launched through
+    # .launch(grid=..., block=...) inside a @cute.jit host function; confirm
+    # that compiling the kernel directly is supported by this version.
+    compiled = cute.compile(KERNELS[test], xc, oc)
+    compiled(xc, oc)
+    print(f"Result: {o.item()}")
+
+
+if __name__ == "__main__":
+    main(sys.argv)