test: debug nvvm.inline_ptx with CUTLASS_LOG_LEVEL=DEBUG
This commit is contained in:
47
tests/unit/test_ptx_debug.py
Normal file
47
tests/unit/test_ptx_debug.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Test: try nvvm.inline_ptx with extra debug info."""
|
||||
import os
|
||||
# Set before the cutlass imports below; presumably the DSL reads this
# variable at import time -- TODO confirm.
os.environ['CUTLASS_LOG_LEVEL'] = 'DEBUG'
|
||||
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
|
||||
@dsl_user_op
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
    """Convert a Float32 to an Int32 through an inline ``cvt.rni.s32.f32``.

    The PTX ``.rni`` modifier rounds to the nearest integer (ties to even),
    which is why the driver in this script expects 3.7 to become 4.

    Args:
        x: Value to convert; it is coerced with ``Float32(x)`` before being
            lowered to an IR value.
        loc: Optional location forwarded to the IR builders.
        ip: Optional insertion point forwarded to the IR builders.

    Returns:
        The result of the inline PTX op wrapped in ``Int32``.
    """
    # One 32-bit integer result; presumably bound to ``$0`` in the PTX
    # string, with the read-only operand bound to ``$1`` -- TODO confirm.
    result_types = [T.i32()]
    operands = [Float32(x).ir_value(loc=loc, ip=ip)]
    converted = nvvm.inline_ptx(
        write_only_args=result_types,
        read_only_args=operands,
        ptx_code="cvt.rni.s32.f32 $0, $1;",
        loc=loc,
        ip=ip,
    )
    return Int32(converted)
|
||||
|
||||
|
||||
@cute.kernel
def minimal_test_kernel(
    input_f32: cute.Tensor,
    output_i32: cute.Tensor,
):
    """Load one float, round it with ``f32_to_i32_rni`` and store the result.

    Only the thread whose x-index is 0 does any work.

    Args:
        input_f32: Tensor whose ``iterator`` is read as a ``Float32``.
        output_i32: Tensor whose ``iterator`` receives the converted value.
    """
    tidx, _, _ = cute.arch.thread_idx()
    # This script imports ``cutlass.cute`` and ``cutlass.torch`` under aliases
    # only, so the bare name ``cutlass`` is never bound here. Use the
    # ``Int32`` / ``Float32`` names imported from ``cutlass.cute.typing``.
    if tidx == Int32(0):
        x = cute.arch.load(input_f32.iterator, Float32)
        result = f32_to_i32_rni(x)
        cute.arch.store(output_i32.iterator, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # One-element probe: the input 3.7 is expected to come back as 4.
    src = torch.tensor([3.7], dtype=torch.float32, device='cuda')
    dst = torch.zeros(1, dtype=torch.int32, device='cuda')

    # Wrap both torch tensors for the DSL with a dynamic layout.
    src_view = cutlass_torch.from_dlpack(src).mark_layout_dynamic(leading_dim=0)
    dst_view = cutlass_torch.from_dlpack(dst).mark_layout_dynamic(leading_dim=0)

    print("Compiling...")
    launcher = cute.compile(minimal_test_kernel, src_view, dst_view)

    print("Running...")
    launcher(src_view, dst_view)

    rounded = dst.item()
    print(f'f32_to_i32_rni(3.7) = {rounded} (expected 4)')
|
||||
87
tests/unit/test_ptx_v2.py
Normal file
87
tests/unit/test_ptx_v2.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Test: try different approaches to nvvm.inline_ptx wrapping."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
|
||||
# Approach 1: keyword-argument call to nvvm.inline_ptx; the result is wrapped in Int32 inside the helper
|
||||
@dsl_user_op
def f32_to_i32_raw(x: Float32, *, loc=None, ip=None) -> Int32:
    """Convert a Float32 to an Int32 via inline PTX, keyword-argument form.

    Emits ``cvt.rni.s32.f32`` (round to nearest integer, ties to even) and
    passes every ``nvvm.inline_ptx`` argument by keyword.

    Args:
        x: Value to convert; coerced with ``Float32(x)`` before lowering.
        loc: Optional location forwarded to the IR builders.
        ip: Optional insertion point forwarded to the IR builders.

    Returns:
        The inline PTX result wrapped in ``Int32``.
    """
    # NOTE(review): the original author expected nvvm.inline_ptx to return a
    # Value that Int32() can wrap directly -- confirm against the binding.
    return Int32(
        nvvm.inline_ptx(
            write_only_args=[T.i32()],
            read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
            ptx_code="cvt.rni.s32.f32 $0, $1;",
            loc=loc,
            ip=ip,
        )
    )
|
||||
|
||||
|
||||
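# Approach 2: pass the output-type list, operand list and PTX string to nvvm.inline_ptx positionally (matching tutorial pattern)
# NOTE: no has_side_effects-style option is set on this call; only the argument-passing style differs from approach 1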
|
||||
@dsl_user_op
def f32_to_i32_v2(x: Float32, *, loc=None, ip=None) -> Int32:
    """Convert a Float32 to an Int32 via inline PTX, positional-argument form.

    Emits the same ``cvt.rni.s32.f32`` instruction as the keyword variant but
    passes the result types, operands and PTX string positionally. Per the
    original author's note, this mirrors the call shape of a tutorial's
    ``ptx_vote_ballot_sync`` helper.

    Args:
        x: Value to convert; coerced with ``Float32(x)`` before lowering.
        loc: Optional location forwarded to the IR builders.
        ip: Optional insertion point forwarded to the IR builders.

    Returns:
        The inline PTX result wrapped in ``Int32``.
    """
    result_types = [T.i32()]
    operands = [Float32(x).ir_value(loc=loc, ip=ip)]
    ptx = "cvt.rni.s32.f32 $0, $1;"
    # The first three arguments stay positional on purpose: that is the
    # variation this approach is probing.
    converted = nvvm.inline_ptx(result_types, operands, ptx, loc=loc, ip=ip)
    return Int32(converted)
|
||||
|
||||
|
||||
@cute.kernel
def test_kernel_v1(
    input_f32: cute.Tensor,
    output_i32: cute.Tensor,
):
    """Exercise ``f32_to_i32_raw`` on a single element.

    Only the thread whose x-index is 0 loads, converts and stores.

    Args:
        input_f32: Tensor whose ``iterator`` is read as a ``cutlass.Float32``.
        output_i32: Tensor whose ``iterator`` receives the converted value.
    """
    tid_x, _tid_y, _tid_z = cute.arch.thread_idx()
    if tid_x == cutlass.Int32(0):
        loaded = cute.arch.load(input_f32.iterator, cutlass.Float32)
        cute.arch.store(output_i32.iterator, f32_to_i32_raw(loaded))
|
||||
|
||||
|
||||
@cute.kernel
def test_kernel_v2(
    input_f32: cute.Tensor,
    output_i32: cute.Tensor,
):
    """Exercise ``f32_to_i32_v2`` on a single element.

    Only the thread whose x-index is 0 loads, converts and stores.

    Args:
        input_f32: Tensor whose ``iterator`` is read as a ``cutlass.Float32``.
        output_i32: Tensor whose ``iterator`` receives the converted value.
    """
    tid_x, _tid_y, _tid_z = cute.arch.thread_idx()
    if tid_x == cutlass.Int32(0):
        loaded = cute.arch.load(input_f32.iterator, cutlass.Float32)
        cute.arch.store(output_i32.iterator, f32_to_i32_v2(loaded))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # One-element probe shared by both variants.
    src = torch.tensor([3.7], dtype=torch.float32, device='cuda')
    dst = torch.zeros(1, dtype=torch.int32, device='cuda')
    src_view = cutlass_torch.from_dlpack(src).mark_layout_dynamic(leading_dim=0)
    dst_view = cutlass_torch.from_dlpack(dst).mark_layout_dynamic(leading_dim=0)

    # (tag used in result lines, banner printed first, kernel to try)
    cases = (
        ("V1", "=== Test V1 (raw result) ===", test_kernel_v1),
        ("V2", "\n=== Test V2 (list-style args) ===", test_kernel_v2),
    )

    for position, (tag, banner, kernel) in enumerate(cases):
        if position:
            # Clear the previous variant's output so a stale value cannot
            # be mistaken for a fresh result.
            dst.zero_()

        print(banner)
        # Best-effort on purpose: a failure in one variant is reported and
        # must not stop the next variant from being tried.
        try:
            launcher = cute.compile(kernel, src_view, dst_view)
            launcher(src_view, dst_view)
            print(f'{tag}: f32_to_i32(3.7) = {dst.item()}')
        except Exception as err:
            print(f'{tag} FAILED: {err}')
|
||||
Reference in New Issue
Block a user