NVFP4-1.1: add Int32 cast diagnostic test
This commit is contained in:
31
tests/unit/test_int32_cast.py
Normal file
31
tests/unit/test_int32_cast.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Quick test: Float32 to Int32 conversion in CuTeDSL."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
|
||||
@cute.kernel
def test_int32_cast_kernel(
    input_f32: cute.Tensor,
    output_i32: cute.Tensor,
):
    """Convert one Float32 value to Int32 on the device.

    Only the thread whose first ``cute.arch.thread_idx()`` component equals
    zero does any work: it loads a ``cutlass.Float32`` through
    ``input_f32.iterator``, converts it with ``cutlass.Int32(...)`` and
    stores the result through ``output_i32.iterator``.

    Args:
        input_f32: Tensor whose iterator is read as a ``cutlass.Float32``.
        output_i32: Tensor whose iterator receives the ``cutlass.Int32``.
    """
    thread_x, _, _ = cute.arch.thread_idx()
    if thread_x == cutlass.Int32(0):
        value_f32 = cute.arch.load(input_f32.iterator, cutlass.Float32)
        # The conversion under test: the host script inspects the stored
        # integer to see how the fractional part was handled.
        value_i32 = cutlass.Int32(value_f32)
        cute.arch.store(output_i32.iterator, value_i32, cutlass.Int32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # One-element probe: 3.7 becomes 3 if the cast truncates and 4 if it
    # rounds to nearest, so the printed value identifies the behaviour.
    src = torch.tensor([3.7], dtype=torch.float32, device='cuda')
    out = torch.zeros(1, dtype=torch.int32, device='cuda')

    # Wrap the torch tensors for CuTe and mark dimension 0 as the dynamic
    # leading dimension.
    src_view = cutlass_torch.from_dlpack(src)
    src_view = src_view.mark_layout_dynamic(leading_dim=0)
    out_view = cutlass_torch.from_dlpack(out)
    out_view = out_view.mark_layout_dynamic(leading_dim=0)

    import cuda.bindings.driver as cuda

    # NOTE(review): `stream` is constructed but never passed to the compile
    # or launch calls below -- confirm whether it was meant to be used.
    raw_stream = torch.cuda.current_stream().cuda_stream
    stream = cuda.CUstream(raw_stream)

    # NOTE(review): a @cute.kernel is compiled and invoked directly here,
    # with no @cute.jit host wrapper or explicit launch configuration --
    # confirm this is supported by the CuTe DSL version in use.
    print("Compiling...")
    kernel_fn = cute.compile(test_int32_cast_kernel, src_view, out_view)

    print("Running...")
    kernel_fn(src_view, out_view)

    print(f'Result: {out.item()} (3=trunc, 4=RNE)')
|
||||
Reference in New Issue
Block a user