NVFP4-1.1: ultra-minimal test — Float32 comparison + Int32 select
This commit is contained in:
31
tests/unit/test_ultra_minimal.py
Normal file
31
tests/unit/test_ultra_minimal.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Ultra-minimal test: Float32 comparison + Int32 assignment in CuTeDSL."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
|
||||
@cute.kernel
def ultra_minimal_kernel(
    input_f32: cute.Tensor,
    output_i32: cute.Tensor,
):
    """Compare a constant Float32 against a threshold and store an Int32 flag.

    Only the thread whose x-index is 0 does any work. It compares the
    constant 3.7 with 2.0 and passes ``output_i32.iterator`` together with
    the selected value (1 when the comparison holds, otherwise 0) to
    ``cute.arch.store``.

    Args:
        input_f32: Not read by the kernel body; the value under test is a
            compile-time constant rather than a load.
        output_i32: Tensor whose ``iterator`` is the destination handed to
            ``cute.arch.store``.
    """
    thread_x, _, _ = cute.arch.thread_idx()
    if thread_x == cutlass.Int32(0):
        # Start from the "comparison failed" value; the branch below
        # overwrites it when the comparison succeeds.
        selected = cutlass.Int32(0)
        value = cutlass.Float32(3.7)  # no load, just a constant
        if value > cutlass.Float32(2.0):
            selected = cutlass.Int32(1)
        cute.arch.store(output_i32.iterator, selected)
|
||||
|
||||
|
||||
@cute.jit
def _launch_ultra_minimal(input_f32: cute.Tensor, output_i32: cute.Tensor):
    """Host-side entry point that launches ``ultra_minimal_kernel``.

    A ``@cute.kernel`` function is device code and has to be launched from
    a ``@cute.jit`` host function with an explicit grid/block configuration.
    A single block of a single thread is enough here: the kernel only acts
    when the thread x-index is 0.
    """
    ultra_minimal_kernel(input_f32, output_i32).launch(
        grid=(1, 1, 1),
        block=(1, 1, 1),
    )


if __name__ == "__main__":
    # One-element buffers: `out` receives the kernel's Int32 result, while
    # `dummy` only fills the kernel's Float32 tensor parameter.
    out = torch.zeros(1, dtype=torch.int32, device='cuda')
    dummy = torch.zeros(1, dtype=torch.float32, device='cuda')
    dc = cutlass_torch.from_dlpack(dummy).mark_layout_dynamic(leading_dim=0)
    oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)

    print("Compiling...")
    # Compile the host launcher, not the device kernel itself, so that the
    # launch configuration is part of what gets compiled and executed.
    compiled = cute.compile(_launch_ultra_minimal, dc, oc)

    print("Running...")
    compiled(dc, oc)

    print(f'Result: {out.item()} (expected 1)')
    # Fail loudly so that a wrong result is visible to whatever runs this
    # script, instead of only appearing in the printed text.
    if out.item() != 1:
        raise SystemExit(1)
|
||||
Reference in New Issue
Block a user