test: sub-process isolation for each f32→i32 approach
This commit is contained in:
18
tests/unit/test_ptx_runner.py
Normal file
18
tests/unit/test_ptx_runner.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Runner: test each f32→i32 approach in a separate process."""
|
||||
import subprocess
import sys
from pathlib import Path

# Resolve the worker script next to this file so the runner works from any
# working directory, not only when launched from inside tests/unit/.
SUBPROC_SCRIPT = Path(__file__).resolve().parent / "test_ptx_subproc.py"

# Upper bound, in seconds, on how long a single approach may take.
TIMEOUT_S = 120

# Each approach runs in its own interpreter so that a hard abort in one child
# (the worker's docstring mentions "LLVM ERROR") cannot take down the others.
for approach in ["arith", "nvvm_ptx", "nvvm_multiline"]:
    print(f"\n{'='*50}")
    print(f"Testing: {approach}")
    print(f"{'='*50}")
    try:
        result = subprocess.run(
            [sys.executable, str(SUBPROC_SCRIPT), approach],
            capture_output=True,
            text=True,
            timeout=TIMEOUT_S,
        )
    except subprocess.TimeoutExpired:
        # A hung child must not abort the remaining approaches; report it
        # and move on.
        print(f"TIMEOUT: no result after {TIMEOUT_S}s")
        continue
    print(result.stdout)
    if result.stderr:
        # Only the first 500 characters of stderr are shown.
        print(f"STDERR: {result.stderr[:500]}")
    print(f"Exit code: {result.returncode}")
|
||||
76
tests/unit/test_ptx_subproc.py
Normal file
76
tests/unit/test_ptx_subproc.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Test: run each approach in a separate process to survive LLVM ERROR."""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import sys
|
||||
|
||||
# Approach name comes from the first CLI argument; defaults to "arith".
approach = sys.argv[1] if len(sys.argv) > 1 else "arith"

# Inline-PTX bodies for the two nvvm-based approaches, keyed by approach name.
_PTX_BY_APPROACH = {
    "nvvm_ptx": "cvt.rni.s32.f32 $0, $1;",
    "nvvm_multiline": "{\n\tcvt.rni.s32.f32 $0, $1;\n\t}",
}

# Reject unknown names up front, before any approach-specific import.
if approach != "arith" and approach not in _PTX_BY_APPROACH:
    print(f"Unknown approach: {approach}")
    sys.exit(1)

from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass.cute.typing import Float32, Int32

if approach == "arith":
    from cutlass._mlir.dialects import arith

    @dsl_user_op
    def f32_to_i32(x: Float32, *, loc=None, ip=None) -> Int32:
        """Convert a Float32 to Int32 using the arith dialect's fptosi op."""
        operand = Float32(x).ir_value(loc=loc, ip=ip)
        return Int32(arith.fptosi(T.i32(), operand, loc=loc, ip=ip))

else:
    from cutlass._mlir.dialects import nvvm

    _PTX_CODE = _PTX_BY_APPROACH[approach]

    @dsl_user_op
    def f32_to_i32(x: Float32, *, loc=None, ip=None) -> Int32:
        """Convert a Float32 to Int32 through an nvvm inline-PTX cvt.rni."""
        operand = Float32(x).ir_value(loc=loc, ip=ip)
        converted = nvvm.inline_ptx(
            write_only_args=[T.i32()],
            read_only_args=[operand],
            ptx_code=_PTX_CODE,
            loc=loc,
            ip=ip,
        )
        return Int32(converted)
|
||||
|
||||
|
||||
# Kernel under test: the thread whose x-index is 0 loads one Float32 through
# inp's iterator, converts it with the f32_to_i32 variant selected above, and
# stores the Int32 result through out's iterator. Other threads do nothing.
@cute.kernel
def test_k(inp: cute.Tensor, out: cute.Tensor):
    # Only the first component of the thread index is used.
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == Int32(0):
        x = cute.arch.load(inp.iterator, Float32)
        # f32_to_i32 is whichever implementation the chosen approach defined.
        r = f32_to_i32(x)
        cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # One-element input (3.7) and a zeroed one-element int32 output on the GPU.
    x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
    o = torch.zeros(1, dtype=torch.int32, device='cuda')

    # Wrap the torch tensors for the CuTe DSL.
    xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
    oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)

    print(f"Approach: {approach}")
    # Progress markers show which stage was reached if the process dies.
    print("Compiling...")
    compiled = cute.compile(test_k, xc, oc)
    print("Running...")
    compiled(xc, oc)

    # The expected value depends on the conversion's rounding mode:
    # arith.fptosi rounds toward zero (3.7 -> 3), while the PTX
    # cvt.rni variants round to the nearest integer (3.7 -> 4).
    expected = 3 if approach == "arith" else 4
    print(f"Result: {o.item()} (expected {expected})")
|
||||
Reference in New Issue
Block a user