32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
"""Minimal test: just the threshold rounding function in CuTeDSL."""
|
|
import torch
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as cutlass_torch
|
|
from dsv4.kernels.gemm.fp4_quant import round_rne_u0_8
|
|
|
|
|
|
@cute.kernel
def threshold_test_kernel(
    input_f32: cute.Tensor,  # (1,) Float32 input
    output_i32: cute.Tensor,  # (1,) Int32 output
):
    """Apply ``round_rne_u0_8`` to a single Float32 value on the device.

    Only thread 0 of the block does any work. It loads the first element of
    ``input_f32``, passes it through ``round_rne_u0_8`` (the rounding helper
    under test, imported from ``dsv4.kernels.gemm.fp4_quant``), and writes the
    result to the first element of ``output_i32``.

    Args:
        input_f32: Device tensor whose first element is the value to round.
        output_i32: Device tensor that receives the rounded result.
    """
    tidx, _, _ = cute.arch.thread_idx()
    # Every other thread in the block exits immediately; one scalar is enough.
    if tidx == cutlass.Int32(0):
        x = cute.arch.load(input_f32.iterator, cutlass.Float32)
        result = round_rne_u0_8(x)
        # NOTE(review): assumes round_rne_u0_8 returns an Int32-compatible
        # value, since it is stored into an Int32 buffer — TODO confirm.
        cute.arch.store(output_i32.iterator, result)


@cute.jit
def threshold_test(input_f32: cute.Tensor, output_i32: cute.Tensor):
    """Host-side JIT entry point that launches ``threshold_test_kernel``.

    A ``@cute.kernel`` cannot be compiled or called on its own. It has to be
    launched with an explicit grid/block configuration from a ``@cute.jit``
    function, and this wrapper is what ``cute.compile`` builds.
    """
    # One block of one warp; the kernel itself restricts the work to thread 0.
    threshold_test_kernel(input_f32, output_i32).launch(
        grid=[1, 1, 1],
        block=[32, 1, 1],
    )


if __name__ == "__main__":
    x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
    out = torch.zeros(1, dtype=torch.int32, device='cuda')

    xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
    oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)

    print("Compiling...")
    # Compile the JIT launcher, not the raw kernel (see threshold_test).
    compiled = cute.compile(threshold_test, xc, oc)

    print("Running...")
    compiled(xc, oc)
    # Wait for the kernel to finish (and surface any launch error here)
    # before reading its output back on the host.
    torch.cuda.synchronize()

    print(f'round(3.7) = {out.item()} (expected 4)')
|