test: different constraint strings + bitcast approach

This commit is contained in:
2026-05-28 04:50:09 +00:00
parent 4806e9ba11
commit c55c237fcd

View File

@@ -0,0 +1,87 @@
"""Test: try different constraint strings for llvm.inline_asm cvt.rni."""
import torch
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
from cutlass.cutlass_dsl import dsl_user_op
from cutlass._mlir.dialects import llvm
from cutlass.cute.typing import Float32, Int32
import sys
approach = sys.argv[1] if len(sys.argv) > 1 else "r_r"
# Try "=r,r" (both as general 32-bit registers)
@dsl_user_op
def f32_to_i32_r_r(x: Float32, *, loc=None, ip=None) -> Int32:
val_i32 = llvm.inline_asm(
Int32._mlir_type(),
[Float32(x).ir_value(loc=loc, ip=ip)],
"cvt.rni.s32.f32 $0, $1;",
"=r,r",
has_side_effects=False,
is_align_stack=False,
loc=loc,
ip=ip,
)
return Int32(val_i32)
# Try bitcast approach: treat float as int, then do integer operations
@dsl_user_op
def f32_bitcast_to_i32(x: Float32, *, loc=None, ip=None) -> Int32:
# Bitcast float to int (reinterpret bits, not convert)
val_i32 = llvm.bitcast(Int32._mlir_type(), Float32(x).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
return Int32(val_i32)
# Try: floor(x) via cute.floor, then bitcast - no, floor returns float
# Try: truncate via cute.arch operations
# Actually, let's try: use llvm.inline_asm with just integer registers
# The idea: bitcast float to i32, then in PTX re-interpret as float and cvt
@dsl_user_op
def f32_to_i32_via_bitcast(x: Float32, *, loc=None, ip=None) -> Int32:
# Bitcast float bits to int, then PTX mov + cvt
bits = Float32(x).ir_value(loc=loc, ip=ip)
val_i32 = llvm.inline_asm(
Int32._mlir_type(),
[bits],
"{\n\tcvt.rni.s32.f32 $0, $1;\n\t}",
"=r,f",
has_side_effects=False,
is_align_stack=False,
loc=loc,
ip=ip,
)
return Int32(val_i32)
FUNCS = {
"r_r": f32_to_i32_r_r,
"bitcast": f32_bitcast_to_i32,
"via_bitcast": f32_to_i32_via_bitcast,
}
func = FUNCS[approach]
@cute.kernel
def test_k(inp: cute.Tensor, out: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
if tidx == Int32(0):
x = cute.arch.load(inp.iterator, Float32)
r = func(x)
cute.arch.store(out.iterator, r)
if __name__ == "__main__":
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
o = torch.zeros(1, dtype=torch.int32, device='cuda')
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
print(f"Approach: {approach}")
print("Compiling...")
compiled = cute.compile(test_k, xc, oc)
print("Running...")
compiled(xc, oc)
print(f"Result: {o.item()}")