NVFP4-1.1: add Int32 cast diagnostic test

This commit is contained in:
2026-05-28 03:59:01 +00:00
parent e565ebce91
commit a05a76bb6b

View File

@@ -0,0 +1,31 @@
"""Quick test: Float32 to Int32 conversion in CuTeDSL."""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
@cute.kernel
def test_int32_cast_kernel(
input_f32: cute.Tensor,
output_i32: cute.Tensor,
):
tidx, _, _ = cute.arch.thread_idx()
if tidx == cutlass.Int32(0):
f = cute.arch.load(input_f32.iterator, cutlass.Float32)
i = cutlass.Int32(f) # float-to-int conversion
cute.arch.store(output_i32.iterator, i, cutlass.Int32)
if __name__ == "__main__":
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
out = torch.zeros(1, dtype=torch.int32, device='cuda')
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)
import cuda.bindings.driver as cuda
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
print("Compiling...")
compiled = cute.compile(test_int32_cast_kernel, xc, oc)
print("Running...")
compiled(xc, oc)
print(f'Result: {out.item()} (3=trunc, 4=RNE)')