diff --git a/tests/unit/test_ultra_minimal.py b/tests/unit/test_ultra_minimal.py new file mode 100644 index 00000000..60e8f890 --- /dev/null +++ b/tests/unit/test_ultra_minimal.py @@ -0,0 +1,31 @@ +"""Ultra-minimal test: Float32 comparison + Int32 assignment in CuTeDSL.""" +import torch +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch + + +@cute.kernel +def ultra_minimal_kernel( + input_f32: cute.Tensor, + output_i32: cute.Tensor, +): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + x = cutlass.Float32(3.7) # no load, just a constant + r = cutlass.Int32(0) + if x > cutlass.Float32(2.0): + r = cutlass.Int32(1) + cute.arch.store(output_i32.iterator, r) + + +if __name__ == "__main__": + out = torch.zeros(1, dtype=torch.int32, device='cuda') + dummy = torch.zeros(1, dtype=torch.float32, device='cuda') + dc = cutlass_torch.from_dlpack(dummy).mark_layout_dynamic(leading_dim=0) + oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0) + print("Compiling...") + compiled = cute.compile(ultra_minimal_kernel, dc, oc) + print("Running...") + compiled(dc, oc) + print(f'Result: {out.item()} (expected 1)')