test: f32 vs i32 GMEM store
This commit is contained in:
38
tests/unit/test_dtype_store.py
Normal file
38
tests/unit/test_dtype_store.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Test: does Float32 GMEM store work when Int32 doesn't?"""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
import sys
|
||||
|
||||
test = sys.argv[1] if len(sys.argv) > 1 else "f32_store"
|
||||
|
||||
@cute.kernel
|
||||
def f32_store(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Float32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
cute.arch.store(out.iterator, x)
|
||||
|
||||
@cute.kernel
|
||||
def i32_store(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
cute.arch.store(out.iterator, Int32(42))
|
||||
|
||||
KERNELS = {"f32_store": f32_store, "i32_store": i32_store}
|
||||
k = KERNELS[test]
|
||||
|
||||
if __name__ == "__main__":
|
||||
if test == "f32_store":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.float32, device='cuda')
|
||||
else:
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print(f"Test: {test}")
|
||||
compiled = cute.compile(k, xc, oc)
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
Reference in New Issue
Block a user