Files
nvfp4-megamoe-kernel/tests/unit/test_ptx_debug.py

48 lines
1.4 KiB
Python

"""Test: try nvvm.inline_ptx with extra debug info."""
import os
os.environ['CUTLASS_LOG_LEVEL'] = 'DEBUG'
import torch
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import nvvm
from cutlass.cute.typing import Float32, Int32
@dsl_user_op
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
    """Convert ``x`` to a signed 32-bit integer with PTX ``cvt.rni.s32.f32``.

    ``.rni`` rounds to the nearest integer, choosing the even one on ties
    (per the PTX ISA). The instruction is emitted via ``nvvm.inline_ptx``:
    in the PTX string, ``$0`` is the single write-only i32 operand and ``$1``
    is ``x`` lowered to an IR value.

    Args:
        x: value to convert; it is wrapped in ``Float32`` before lowering.
        loc, ip: MLIR location / insertion point forwarded by ``dsl_user_op``.

    Returns:
        The converted value wrapped as ``Int32``.
    """
    src_value = Float32(x).ir_value(loc=loc, ip=ip)
    converted = nvvm.inline_ptx(
        ptx_code="cvt.rni.s32.f32 $0, $1;",
        read_only_args=[src_value],
        write_only_args=[T.i32()],
        loc=loc,
        ip=ip,
    )
    return Int32(converted)
@cute.kernel
def minimal_test_kernel(
    input_f32: cute.Tensor,
    output_i32: cute.Tensor,
):
    """Round one float with ``f32_to_i32_rni`` and write the int result.

    Only thread 0 does any work. It loads a single ``Float32`` from the start
    of ``input_f32``'s storage (its iterator), converts it, and stores the
    ``Int32`` result at the start of ``output_i32``'s storage.

    Args:
        input_f32: source tensor; its first stored value is converted.
        output_i32: destination tensor; its first stored value is overwritten.
    """
    tidx, _, _ = cute.arch.thread_idx()
    # Use the directly imported Int32 / Float32: `import cutlass.cute as cute`
    # binds only `cute`, so `cutlass.Int32` would raise NameError here.
    if tidx == Int32(0):
        x = cute.arch.load(input_f32.iterator, Float32)
        result = f32_to_i32_rni(x)
        cute.arch.store(output_i32.iterator, result)
@cute.jit
def launch_minimal_test(input_f32: cute.Tensor, output_i32: cute.Tensor):
    """Host-side JIT entry point that launches ``minimal_test_kernel``.

    ``cute.compile`` takes a ``@cute.jit`` function, and a ``@cute.kernel``
    is launched from one via ``.launch(grid=..., block=...)``. The kernel
    only does work on thread 0, so a single one-thread block is enough.
    """
    minimal_test_kernel(input_f32, output_i32).launch(
        grid=[1, 1, 1],
        block=[1, 1, 1],
    )


if __name__ == "__main__":
    # cvt.rni rounds to nearest (ties to even), so 3.7 -> 4.
    expected = 4
    x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
    out = torch.zeros(1, dtype=torch.int32, device='cuda')
    xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
    oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)
    print("Compiling...")
    compiled = cute.compile(launch_minimal_test, xc, oc)
    print("Running...")
    compiled(xc, oc)
    # Ensure the kernel has finished before reading the result back.
    torch.cuda.synchronize()
    got = out.item()
    print(f'f32_to_i32_rni(3.7) = {got} (expected 4)')
    # Exit non-zero on a wrong result so this script can actually fail.
    if got != expected:
        raise SystemExit(f"FAIL: expected {expected}, got {got}")