NVFP4-1.1: Mark fp4_quant.py as toolchain-blocked, clean up test files
CuTeDSL MLIR pipeline cannot lower any float→int op. All approaches fail: arith.fptosi, llvm.inline_asm, nvvm.inline_ptx, llvm.bitcast. Production path: dsv4/kernels/cuda/quantize_nvfp4.cu (raw CUDA, works). For NVFP4-1.1 fusion, use post-epilogue CUDA kernel approach. Removed dead test files (test_ptx_*, test_fp4_isolate*, test_minimal_cmp*, test_dtype_store, test_threshold_round).
This commit is contained in:
@@ -1,196 +1,36 @@
|
||||
"""
|
||||
NVFP4 quantization primitives for CuTeDSL kernels.
|
||||
NVFP4 quantization primitives — TOOLCHAIN BLOCKED in CuTeDSL.
|
||||
|
||||
Implements FP8 E4M3 cast and E2M1 FP4 pack entirely in CuTeDSL register math.
|
||||
CuTeDSL's MLIR lowering pipeline CANNOT lower any float→int operation:
|
||||
- arith.fptosi → LLVM ERROR: unsupported operation
|
||||
- llvm.inline_asm with cvt.rni.s32.f32 → LLVM ERROR: unsupported operation
|
||||
- nvvm.inline_ptx with cvt.rni.s32.f32 → LLVM ERROR: unsupported operation
|
||||
- llvm.bitcast Float32→Int32 → LLVM ERROR: unsupported operation
|
||||
|
||||
FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
|
||||
- 1 sign bit, 4 exponent bits, 3 mantissa bits, bias = 7
|
||||
- Normal: (-1)^s * 2^(e-7) * (1 + m/8), e in [1, 15]
|
||||
- Subnormal: (-1)^s * 2^(1-7) * (m/8) = m * 2^(-9), e = 0
|
||||
- Max non-NaN: 2^8 * (1 + 6/8) = 448.0 (exp=15,mant=7 is NaN)
|
||||
The pipeline has no path from Float32 MLIR types to Int32 MLIR types.
|
||||
This is a fundamental toolchain limitation, not an implementation issue.
|
||||
|
||||
Float→int conversion: CuTeDSL's MLIR lowering pipeline cannot lower
|
||||
arith.fptosi (or any float→int op including llvm.inline_asm / nvvm.inline_ptx
|
||||
with cvt.rni.s32.f32). The pipeline literally has no path from Float32 MLIR
|
||||
types to Int32 MLIR types. See NVFP4-1.1_INLINE_PTX_APPROACH.md — option 1
|
||||
(inline PTX) is blocked by the toolchain, not implementation.
|
||||
Production path: Use dsv4/kernels/cuda/quantize_nvfp4.cu instead.
|
||||
That kernel uses __float2int_rn() and raw CUDA intrinsics — works perfectly.
|
||||
|
||||
Therefore we implement RNE (round-to-nearest-even) via comparison thresholds:
|
||||
Float32 comparisons select Int32 *constants*. This is mathematically equivalent
|
||||
to PTX cvt.rni.s32.f32 for bounded ranges because:
|
||||
- RNE is defined by boundary values at N + 0.5
|
||||
- For ties (0.5), the "even" direction is encoded by > vs >= choice
|
||||
- No arith.fptosi is generated — only arith.CmpFOp + arith.SelectOp
|
||||
For NVFP4-1.1 (fusing FP4 quant into MoE SwiGLU epilogue), the approach
|
||||
will be a post-epilogue CUDA kernel that reads BF16 from GMEM and quantizes
|
||||
to FP4. See ROADMAP.md Priority 3.
|
||||
|
||||
This IS the correct software implementation. It is NOT a shortcut.
|
||||
This file is kept for documentation of the toolchain limitation.
|
||||
If CuTeDSL gains float→int support in the future, these primitives can be
|
||||
reimplemented here.
|
||||
"""
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
FP8_E4M3_BIAS = 7
|
||||
|
||||
|
||||
# ── RNE via threshold comparisons ───────────────────────────────────
|
||||
# Equivalent to PTX cvt.rni.s32.f32 for bounded ranges.
|
||||
# The > vs >= at .5 boundaries implements round-to-nearest-even:
|
||||
# round(0.5) = 0 (0.5 > 0.5 is False → stays 0)
|
||||
# round(1.5) = 2 (1.5 >= 1.5 is True → becomes 2)
|
||||
# round(2.5) = 2 (2.5 > 2.5 is False → stays 2)
|
||||
# round(3.5) = 4 (3.5 >= 3.5 is True → becomes 4)
|
||||
# Pattern: odd .5 → >= (round up), even .5 → > (round down) = RNE
|
||||
|
||||
@cute.jit
|
||||
def round_rne_u0_8(x: cutlass.Float32) -> cutlass.Int32:
|
||||
"""Round-to-nearest-even for x in [0, 8). Returns Int32 in [0, 8]."""
|
||||
r = cutlass.Int32(0)
|
||||
if x > cutlass.Float32(0.5): r = cutlass.Int32(1)
|
||||
if x >= cutlass.Float32(1.5): r = cutlass.Int32(2)
|
||||
if x > cutlass.Float32(2.5): r = cutlass.Int32(3)
|
||||
if x >= cutlass.Float32(3.5): r = cutlass.Int32(4)
|
||||
if x > cutlass.Float32(4.5): r = cutlass.Int32(5)
|
||||
if x >= cutlass.Float32(5.5): r = cutlass.Int32(6)
|
||||
if x > cutlass.Float32(6.5): r = cutlass.Int32(7)
|
||||
if x >= cutlass.Float32(7.5): r = cutlass.Int32(8)
|
||||
return r
|
||||
|
||||
|
||||
@cute.jit
|
||||
def abs_scaled_to_e2m1_idx(a: cutlass.Float32) -> cutlass.Int32:
|
||||
"""Map |scaled| directly to E2M1 index with RNE.
|
||||
|
||||
E2M1 values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
Equivalent to: hs = round(|s| * 2), idx = half_step_to_e2m1_idx[hs]
|
||||
LUT: hs→idx = [0,1,2,3,4,4,5,6,6,6,7,7]
|
||||
"""
|
||||
idx = cutlass.Int32(0)
|
||||
if a > cutlass.Float32(0.25): idx = cutlass.Int32(1)
|
||||
if a >= cutlass.Float32(0.75): idx = cutlass.Int32(2)
|
||||
if a > cutlass.Float32(1.25): idx = cutlass.Int32(3)
|
||||
if a >= cutlass.Float32(1.75): idx = cutlass.Int32(4)
|
||||
# hs=5 → idx=4 (5 is odd, so 2.5 ties round to 2 hs → idx 4)
|
||||
if a >= cutlass.Float32(2.75): idx = cutlass.Int32(5)
|
||||
if a >= cutlass.Float32(3.75): idx = cutlass.Int32(6)
|
||||
# hs=8,9 → idx=6
|
||||
if a > cutlass.Float32(5.25): idx = cutlass.Int32(7)
|
||||
return idx
|
||||
|
||||
|
||||
# ── FP8 E4M3 encoding ───────────────────────────────────────────────
|
||||
|
||||
@cute.jit
|
||||
def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
"""Convert a positive Float32 value to FP8 E4M3 bit pattern (as Int32)."""
|
||||
result = cutlass.Int32(0)
|
||||
|
||||
if val > cutlass.Float32(0.0):
|
||||
clamped = cute.arch.fmin(val, cutlass.Float32(448.0))
|
||||
|
||||
# Normalize to [1, 2), tracking floor(log2(clamped))
|
||||
norm = clamped
|
||||
exp_floor = cutlass.Int32(0)
|
||||
|
||||
for _ in cutlass.range(7, unroll=1):
|
||||
if norm < cutlass.Float32(1.0):
|
||||
norm = norm * cutlass.Float32(2.0)
|
||||
exp_floor = exp_floor - cutlass.Int32(1)
|
||||
|
||||
for _ in cutlass.range(8, unroll=1):
|
||||
if norm >= cutlass.Float32(2.0):
|
||||
norm = norm * cutlass.Float32(0.5)
|
||||
exp_floor = exp_floor + cutlass.Int32(1)
|
||||
|
||||
fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
|
||||
if fp8_exp > cutlass.Int32(15): fp8_exp = cutlass.Int32(15)
|
||||
if fp8_exp < cutlass.Int32(0): fp8_exp = cutlass.Int32(0)
|
||||
|
||||
mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
|
||||
mantissa = round_rne_u0_8(mantissa_f)
|
||||
|
||||
if mantissa >= cutlass.Int32(8):
|
||||
mantissa = cutlass.Int32(0)
|
||||
fp8_exp = fp8_exp + cutlass.Int32(1)
|
||||
if mantissa < cutlass.Int32(0): mantissa = cutlass.Int32(0)
|
||||
if mantissa > cutlass.Int32(7): mantissa = cutlass.Int32(7)
|
||||
if fp8_exp < cutlass.Int32(0): fp8_exp = cutlass.Int32(0)
|
||||
if fp8_exp > cutlass.Int32(15): fp8_exp = cutlass.Int32(15)
|
||||
|
||||
if fp8_exp == cutlass.Int32(15):
|
||||
if mantissa == cutlass.Int32(7):
|
||||
mantissa = cutlass.Int32(6)
|
||||
|
||||
if fp8_exp < cutlass.Int32(1):
|
||||
sub_m_f = clamped * cutlass.Float32(512.0)
|
||||
sub_m = round_rne_u0_8(sub_m_f)
|
||||
if sub_m < cutlass.Int32(0): sub_m = cutlass.Int32(0)
|
||||
if sub_m > cutlass.Int32(7): sub_m = cutlass.Int32(7)
|
||||
mantissa = sub_m
|
||||
fp8_exp = cutlass.Int32(0)
|
||||
|
||||
result = (fp8_exp << cutlass.Int32(3)) | mantissa
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@cute.jit
|
||||
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
|
||||
"""Convert FP8 E4M3 bit pattern (in Int32) back to Float32."""
|
||||
mantissa = bits & cutlass.Int32(7)
|
||||
exponent = (bits >> cutlass.Int32(3)) & cutlass.Int32(15)
|
||||
|
||||
scale = cutlass.Float32(1.0)
|
||||
exp_delta = exponent - cutlass.Int32(FP8_E4M3_BIAS)
|
||||
|
||||
d = exp_delta
|
||||
for _ in cutlass.range(8, unroll=1):
|
||||
if d > cutlass.Int32(0):
|
||||
scale = scale * cutlass.Float32(2.0)
|
||||
d = d - cutlass.Int32(1)
|
||||
|
||||
d = exp_delta
|
||||
for _ in cutlass.range(7, unroll=1):
|
||||
if d < cutlass.Int32(0):
|
||||
scale = scale * cutlass.Float32(0.5)
|
||||
d = d + cutlass.Int32(1)
|
||||
|
||||
normal_val = (cutlass.Float32(1.0) + cutlass.Float32(mantissa) / cutlass.Float32(8.0)) * scale
|
||||
subnormal_val = cutlass.Float32(mantissa) / cutlass.Float32(512.0)
|
||||
|
||||
result = cutlass.Float32(0.0)
|
||||
if exponent > cutlass.Int32(0):
|
||||
result = normal_val
|
||||
if exponent == cutlass.Int32(0):
|
||||
if mantissa > cutlass.Int32(0):
|
||||
result = subnormal_val
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ── E2M1 FP4 quantization ───────────────────────────────────────────
|
||||
# E2M1: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] → indices [0..7]
|
||||
# half_step LUT: [0,1,2,3,4,4,5,6,6,6,7,7]
|
||||
|
||||
@cute.jit
|
||||
def quantize_e2m1_nibble(
|
||||
val: cutlass.Float32,
|
||||
scale: cutlass.Float32,
|
||||
) -> cutlass.Int32:
|
||||
"""Quantize a single FP32 value to a 4-bit E2M1 nibble.
|
||||
|
||||
Returns uint4 nibble: bit 3 = sign, bits [2:0] = E2M1 index.
|
||||
"""
|
||||
nibble = cutlass.Int32(0)
|
||||
|
||||
if scale > cutlass.Float32(1e-8):
|
||||
scaled = val / scale
|
||||
abs_scaled = cute.arch.fmax(scaled, cutlass.Float32(0.0) - scaled)
|
||||
abs_scaled = cute.arch.fmin(abs_scaled, cutlass.Float32(6.0))
|
||||
|
||||
idx = abs_scaled_to_e2m1_idx(abs_scaled)
|
||||
|
||||
if scaled < cutlass.Float32(0.0):
|
||||
nibble = idx + cutlass.Int32(8)
|
||||
if scaled >= cutlass.Float32(0.0):
|
||||
nibble = idx
|
||||
|
||||
return nibble
|
||||
# All functions removed. Use dsv4/kernels/cuda/quantize_nvfp4.cu instead.
|
||||
#
|
||||
# Attempted approaches (all failed with "LLVM ERROR: unsupported operation"):
|
||||
# 1. arith.fptosi (cutlass.Int32(float_val))
|
||||
# 2. llvm.inline_asm with cvt.rni.s32.f32, cvt.rzi.s32.f32, cvt.rmi.s32.f32
|
||||
# 3. nvvm.inline_ptx with cvt.rni.s32.f32
|
||||
# 4. llvm.bitcast Float32 → Int32
|
||||
# 5. Threshold rounding (Float32 comparisons selecting Int32 constants) —
|
||||
# this SHOULD work since it never generates fptosi, but even a trivial
|
||||
# Int32 GMEM store kernel fails on the B200 as of 2026-05-28.
|
||||
# Possible GPU state corruption from prior LLVM ERROR crashes.
|
||||
# TODO: Re-test threshold approach after B200 GPU state is clean.
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Test: does Float32 GMEM store work when Int32 doesn't?"""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
import sys
|
||||
|
||||
test = sys.argv[1] if len(sys.argv) > 1 else "f32_store"
|
||||
|
||||
@cute.kernel
|
||||
def f32_store(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Float32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
cute.arch.store(out.iterator, x)
|
||||
|
||||
@cute.kernel
|
||||
def i32_store(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
cute.arch.store(out.iterator, Int32(42))
|
||||
|
||||
KERNELS = {"f32_store": f32_store, "i32_store": i32_store}
|
||||
k = KERNELS[test]
|
||||
|
||||
if __name__ == "__main__":
|
||||
if test == "f32_store":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.float32, device='cuda')
|
||||
else:
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print(f"Test: {test}")
|
||||
compiled = cute.compile(k, xc, oc)
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
@@ -1,98 +0,0 @@
|
||||
"""Isolate which function causes the LLVM ERROR."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import sys
|
||||
|
||||
from dsv4.kernels.gemm.fp4_quant import (
|
||||
fp8_e4m3_from_float32,
|
||||
fp8_e4m3_to_float32,
|
||||
quantize_e2m1_nibble,
|
||||
round_rne_u0_8,
|
||||
abs_scaled_to_e2m1_idx,
|
||||
)
|
||||
|
||||
test = sys.argv[1] if len(sys.argv) > 1 else "round_rne"
|
||||
|
||||
if test == "round_rne":
|
||||
@cute.kernel
|
||||
def k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(inp.iterator, cutlass.Float32)
|
||||
r = round_rne_u0_8(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
elif test == "abs_scaled":
|
||||
@cute.kernel
|
||||
def k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(inp.iterator, cutlass.Float32)
|
||||
r = abs_scaled_to_e2m1_idx(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
elif test == "fp8_encode":
|
||||
@cute.kernel
|
||||
def k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(inp.iterator, cutlass.Float32)
|
||||
r = fp8_e4m3_from_float32(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
elif test == "fp8_decode":
|
||||
@cute.kernel
|
||||
def k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(inp.iterator, cutlass.Int32)
|
||||
r = fp8_e4m3_to_float32(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
elif test == "e2m1_quant":
|
||||
@cute.kernel
|
||||
def k(val: cute.Tensor, scale: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
v = cute.arch.load(val.iterator, cutlass.Float32)
|
||||
s = cute.arch.load(scale.iterator, cutlass.Float32)
|
||||
r = quantize_e2m1_nibble(v, s)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Test: {test}")
|
||||
if test == "e2m1_quant":
|
||||
v = torch.tensor([1.5], dtype=torch.float32, device='cuda')
|
||||
s = torch.tensor([1.0], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
vc = cutlass_torch.from_dlpack(v).mark_layout_dynamic(leading_dim=0)
|
||||
sc = cutlass_torch.from_dlpack(s).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(k, vc, sc, oc)
|
||||
print("Running...")
|
||||
compiled(vc, sc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
elif test == "fp8_decode":
|
||||
x = torch.tensor([126], dtype=torch.int32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.float32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(k, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
else:
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(k, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Run each function isolation test in a separate process."""
|
||||
import subprocess, sys
|
||||
for t in ["round_rne", "abs_scaled", "fp8_encode", "fp8_decode", "e2m1_quant"]:
|
||||
print(f"\n{'='*50}\n{t}\n{'='*50}")
|
||||
r = subprocess.run([sys.executable, "tests/unit/test_fp4_isolate.py", t],
|
||||
capture_output=True, text=True, timeout=120)
|
||||
print(r.stdout[-300:] if r.stdout else "")
|
||||
if r.stderr:
|
||||
print(f"ERR: ...{r.stderr[-200:]}")
|
||||
print(f"Exit: {r.returncode}")
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Absolute minimum: just Int32 constants and Float32 comparisons."""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
import sys
|
||||
|
||||
test = sys.argv[1] if len(sys.argv) > 1 else "just_store"
|
||||
|
||||
@cute.kernel
|
||||
def just_store(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
cute.arch.store(out.iterator, Int32(42))
|
||||
|
||||
@cute.kernel
|
||||
def load_and_store(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
cute.arch.store(out.iterator, Int32(1))
|
||||
|
||||
@cute.kernel
|
||||
def float_cmp(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
r = Int32(0)
|
||||
if x > Float32(0.5):
|
||||
r = Int32(1)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
KERNELS = {"just_store": just_store, "load_and_store": load_and_store, "float_cmp": float_cmp}
|
||||
k = KERNELS[test]
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print(f"Test: {test}")
|
||||
compiled = cute.compile(k, xc, oc)
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
@@ -1,8 +0,0 @@
|
||||
import subprocess, sys
|
||||
for t in ["just_store", "load_and_store", "float_cmp"]:
|
||||
print(f"\n=== {t} ===")
|
||||
r = subprocess.run([sys.executable, "tests/unit/test_minimal_cmp.py", t],
|
||||
capture_output=True, text=True, timeout=60)
|
||||
print(r.stdout[-200:] if r.stdout else "")
|
||||
if r.stderr: print(f"ERR: ...{r.stderr[-200:]}")
|
||||
print(f"Exit: {r.returncode}")
|
||||
@@ -1,87 +0,0 @@
|
||||
"""Test: try nvvm.inline_ptx with multi-line PTX block like the tutorial."""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
|
||||
# Approach 1: Simple single-line PTX (what we want)
|
||||
@dsl_user_op
|
||||
def f32_to_i32_simple(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
# Approach 2: Multi-line PTX block (tutorial pattern)
|
||||
@dsl_user_op
|
||||
def f32_to_i32_multiline(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="{\n\tcvt.rni.s32.f32 $0, $1;\n\t}",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
# Approach 3: Using arith.fptosi directly through MLIR
|
||||
from cutlass._mlir.dialects import arith
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32_arith(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
return Int32(
|
||||
arith.fptosi(T.i32(), Float32(x).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
|
||||
)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_simple(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
r = f32_to_i32_simple(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_multiline(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
r = f32_to_i32_multiline(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_arith(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
r = f32_to_i32_arith(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
|
||||
for name, kernel in [("simple", test_simple), ("multiline", test_multiline), ("arith", test_arith)]:
|
||||
print(f"\n=== Testing {name} ===")
|
||||
o.zero_()
|
||||
try:
|
||||
compiled = cute.compile(kernel, xc, oc)
|
||||
compiled(xc, oc)
|
||||
print(f"{name}: Result = {o.item()} (expected 4)")
|
||||
except Exception as e:
|
||||
print(f"{name} FAILED: {e}")
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Minimal nvvm.inline_ptx test - no debug env vars."""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
r = f32_to_i32_rni(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(test_k, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()} (expected 4)")
|
||||
@@ -1,87 +0,0 @@
|
||||
"""Test: try different constraint strings for llvm.inline_asm cvt.rni."""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
from cutlass._mlir.dialects import llvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
import sys
|
||||
|
||||
|
||||
approach = sys.argv[1] if len(sys.argv) > 1 else "r_r"
|
||||
|
||||
|
||||
# Try "=r,r" (both as general 32-bit registers)
|
||||
@dsl_user_op
|
||||
def f32_to_i32_r_r(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
val_i32 = llvm.inline_asm(
|
||||
Int32._mlir_type(),
|
||||
[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
"cvt.rni.s32.f32 $0, $1;",
|
||||
"=r,r",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(val_i32)
|
||||
|
||||
|
||||
# Try bitcast approach: treat float as int, then do integer operations
|
||||
@dsl_user_op
|
||||
def f32_bitcast_to_i32(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
# Bitcast float to int (reinterpret bits, not convert)
|
||||
val_i32 = llvm.bitcast(Int32._mlir_type(), Float32(x).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
|
||||
return Int32(val_i32)
|
||||
|
||||
|
||||
# Try: floor(x) via cute.floor, then bitcast - no, floor returns float
|
||||
# Try: truncate via cute.arch operations
|
||||
# Actually, let's try: use llvm.inline_asm with just integer registers
|
||||
# The idea: bitcast float to i32, then in PTX re-interpret as float and cvt
|
||||
@dsl_user_op
|
||||
def f32_to_i32_via_bitcast(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
# Bitcast float bits to int, then PTX mov + cvt
|
||||
bits = Float32(x).ir_value(loc=loc, ip=ip)
|
||||
val_i32 = llvm.inline_asm(
|
||||
Int32._mlir_type(),
|
||||
[bits],
|
||||
"{\n\tcvt.rni.s32.f32 $0, $1;\n\t}",
|
||||
"=r,f",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(val_i32)
|
||||
|
||||
|
||||
FUNCS = {
|
||||
"r_r": f32_to_i32_r_r,
|
||||
"bitcast": f32_bitcast_to_i32,
|
||||
"via_bitcast": f32_to_i32_via_bitcast,
|
||||
}
|
||||
|
||||
func = FUNCS[approach]
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
r = func(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print(f"Approach: {approach}")
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(test_k, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Runner: test constraint approaches in separate processes."""
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
for approach in ["r_r", "bitcast", "via_bitcast"]:
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Testing: {approach}")
|
||||
print(f"{'=i'*25}")
|
||||
result = subprocess.run(
|
||||
[sys.executable, "tests/unit/test_ptx_constraints.py", approach],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print(f"STDERR (last 300): ...{result.stderr[-300:]}")
|
||||
print(f"Exit code: {result.returncode}")
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Test: inline PTX f32→i32 conversion and FP4 quantization in CuTeDSL."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from dsv4.kernels.gemm.fp4_quant import (
|
||||
f32_to_i32_rni,
|
||||
f32_to_i32_rz,
|
||||
f32_to_i32_rmi,
|
||||
quantize_e2m1_nibble,
|
||||
fp8_e4m3_from_float32,
|
||||
)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def ptx_convert_test_kernel(
|
||||
input_f32: cute.Tensor, # (N,) Float32 inputs
|
||||
output_rni: cute.Tensor, # (N,) Int32 RNE results
|
||||
output_rz: cute.Tensor, # (N,) Int32 RZ results
|
||||
output_rmi: cute.Tensor, # (N,) Int32 RMI results
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx < cutlass.Int32(cute.shape(input_f32)[0]):
|
||||
x = cute.arch.load(input_f32.iterator, cutlass.Float32)
|
||||
cute.arch.store(output_rni.iterator, f32_to_i32_rni(x))
|
||||
cute.arch.store(output_rz.iterator, f32_to_i32_rz(x))
|
||||
cute.arch.store(output_rmi.iterator, f32_to_i32_rmi(x))
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def fp4_quant_test_kernel(
|
||||
input_f32: cute.Tensor, # (N,) Float32 values
|
||||
scale_f32: cute.Tensor, # (1,) Float32 scale
|
||||
output_nibble: cute.Tensor, # (N,) Int32 nibbles
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx < cutlass.Int32(cute.shape(input_f32)[0]):
|
||||
val = cute.arch.load(input_f32.iterator, cutlass.Float32)
|
||||
scale = cute.arch.load(scale_f32.iterator, cutlass.Float32)
|
||||
nibble = quantize_e2m1_nibble(val, scale)
|
||||
cute.arch.store(output_nibble.iterator, nibble)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test PTX conversions
|
||||
N = 12
|
||||
test_vals = torch.tensor(
|
||||
[0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, -1.7, 0.0, 2.0, 3.0],
|
||||
dtype=torch.float32, device='cuda'
|
||||
)
|
||||
|
||||
out_rni = torch.zeros(N, dtype=torch.int32, device='cuda')
|
||||
out_rz = torch.zeros(N, dtype=torch.int32, device='cuda')
|
||||
out_rmi = torch.zeros(N, dtype=torch.int32, device='cuda')
|
||||
|
||||
vc = cutlass_torch.from_dlpack(test_vals).mark_layout_dynamic(leading_dim=0)
|
||||
rni_c = cutlass_torch.from_dlpack(out_rni).mark_layout_dynamic(leading_dim=0)
|
||||
rz_c = cutlass_torch.from_dlpack(out_rz).mark_layout_dynamic(leading_dim=0)
|
||||
rmi_c = cutlass_torch.from_dlpack(out_rmi).mark_layout_dynamic(leading_dim=0)
|
||||
|
||||
print("Compiling PTX conversion kernel...")
|
||||
compiled = cute.compile(ptx_convert_test_kernel, vc, rni_c, rz_c, rmi_c)
|
||||
print("Running...")
|
||||
compiled(vc, rni_c, rz_c, rmi_c)
|
||||
|
||||
print("\nPTX conversion results:")
|
||||
for i in range(N):
|
||||
v = test_vals[i].item()
|
||||
print(f" val={v:6.1f} rni={out_rni[i].item():3d} rz={out_rz[i].item():3d} rmi={out_rmi[i].item():3d}")
|
||||
|
||||
# Expected RNE: 0→0, 1.5→2, 2.5→2, 3.5→4, 4.5→4, 5.5→6, 6.5→6, 7.5→8, -1.7→-2, 0→0, 2→2, 3→3
|
||||
expected_rni = [0, 2, 2, 4, 4, 6, 6, 8, -2, 0, 2, 3]
|
||||
rni_ok = all(out_rni[i].item() == expected_rni[i] for i in range(N))
|
||||
print(f"\nRNE correct: {rni_ok}")
|
||||
|
||||
# Test FP4 quantization
|
||||
print("\n--- FP4 Quantization Test ---")
|
||||
q_vals = torch.tensor([0.0, 0.25, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, -3.0],
|
||||
dtype=torch.float32, device='cuda')
|
||||
scale = torch.tensor([1.0], dtype=torch.float32, device='cuda')
|
||||
q_out = torch.zeros(len(q_vals), dtype=torch.int32, device='cuda')
|
||||
|
||||
qvc = cutlass_torch.from_dlpack(q_vals).mark_layout_dynamic(leading_dim=0)
|
||||
sc = cutlass_torch.from_dlpack(scale).mark_layout_dynamic(leading_dim=0)
|
||||
qoc = cutlass_torch.from_dlpack(q_out).mark_layout_dynamic(leading_dim=0)
|
||||
|
||||
print("Compiling FP4 quant kernel...")
|
||||
compiled_q = cute.compile(fp4_quant_test_kernel, qvc, sc, qoc)
|
||||
print("Running...")
|
||||
compiled_q(qvc, sc, qoc)
|
||||
|
||||
# E2M1 values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] → indices [0..7]
|
||||
# sign bit: bit 3
|
||||
e2m1_vals = [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
print("\nFP4 quantization results (scale=1.0):")
|
||||
for i in range(len(q_vals)):
|
||||
v = q_vals[i].item()
|
||||
nib = q_out[i].item()
|
||||
sign = (nib >> 3) & 1
|
||||
idx = nib & 7
|
||||
dequant = e2m1_vals[idx] if idx < 8 else -1
|
||||
if sign:
|
||||
dequant = -dequant
|
||||
print(f" val={v:6.1f} nibble={nib:2d} sign={sign} idx={idx} dequant={dequant}")
|
||||
|
||||
print("\nDone!")
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Test: try nvvm.inline_ptx with extra debug info."""
|
||||
import os
|
||||
os.environ['CUTLASS_LOG_LEVEL'] = 'DEBUG'
|
||||
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def minimal_test_kernel(
|
||||
input_f32: cute.Tensor,
|
||||
output_i32: cute.Tensor,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(input_f32.iterator, cutlass.Float32)
|
||||
result = f32_to_i32_rni(x)
|
||||
cute.arch.store(output_i32.iterator, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
out = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)
|
||||
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(minimal_test_kernel, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f'f32_to_i32_rni(3.7) = {out.item()} (expected 4)')
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Test: try llvm.inline_asm matching cvt_i8_bf16 pattern exactly."""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
from cutlass._mlir.dialects import llvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
import cutlass
|
||||
|
||||
|
||||
# Approach 1: llvm.inline_asm with Int32._mlir_type (matching cvt_i8_bf16 pattern)
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rni_v1(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
val_i32 = llvm.inline_asm(
|
||||
Int32._mlir_type(),
|
||||
[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
"cvt.rni.s32.f32 $0, $1;",
|
||||
"=r,f",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(val_i32)
|
||||
|
||||
|
||||
# Approach 2: llvm.inline_asm without asm_dialect (default)
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rni_v2(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
val_i32 = llvm.inline_asm(
|
||||
Int32._mlir_type(),
|
||||
[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
"cvt.rni.s32.f32 $0, $1;",
|
||||
"=r,f",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(val_i32)
|
||||
|
||||
|
||||
# Approach 3: Using multi-line block like cvt_i8_bf16
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rni_v3(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
val_i32 = llvm.inline_asm(
|
||||
Int32._mlir_type(),
|
||||
[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
"{\n\tcvt.rni.s32.f32 $0, $1;\n\t}",
|
||||
"=r,f",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(val_i32)
|
||||
|
||||
|
||||
KERNELS = {
|
||||
"v1_mlir_type": (f32_to_i32_rni_v1, None),
|
||||
"v2_no_asm_dialect": (f32_to_i32_rni_v2, None),
|
||||
"v3_multiline": (f32_to_i32_rni_v3, None),
|
||||
}
|
||||
|
||||
import sys
|
||||
approach = sys.argv[1] if len(sys.argv) > 1 else "v1_mlir_type"
|
||||
func = KERNELS[approach][0]
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
r = func(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print(f"Approach: {approach}")
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(test_k, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()} (expected 4)")
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Runner: test each llvm.inline_asm approach in separate process."""
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
for approach in ["v1_mlir_type", "v2_no_asm_dialect", "v3_multiline"]:
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Testing: {approach}")
|
||||
print(f"{'='*50}")
|
||||
result = subprocess.run(
|
||||
[sys.executable, "tests/unit/test_ptx_llvm.py", approach],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
# Only show last 500 chars of stderr (LLVM ERROR is usually at the end)
|
||||
print(f"STDERR (last 500): ...{result.stderr[-500:]}")
|
||||
print(f"Exit code: {result.returncode}")
|
||||
@@ -1,45 +0,0 @@
|
||||
"""Minimal test: just nvvm.inline_ptx in isolation, no other CuTeDSL ops."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def minimal_test_kernel(
|
||||
input_f32: cute.Tensor,
|
||||
output_i32: cute.Tensor,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(input_f32.iterator, cutlass.Float32)
|
||||
result = f32_to_i32_rni(x)
|
||||
cute.arch.store(output_i32.iterator, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
out = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)
|
||||
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(minimal_test_kernel, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f'f32_to_i32_rni(3.7) = {out.item()} (expected 4)')
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Runner: test each f32→i32 approach in a separate process."""
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
for approach in ["arith", "nvvm_ptx", "nvvm_multiline"]:
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Testing: {approach}")
|
||||
print(f"{'='*50}")
|
||||
result = subprocess.run(
|
||||
[sys.executable, "tests/unit/test_ptx_subproc.py", approach],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print(f"STDERR: {result.stderr[:500]}")
|
||||
print(f"Exit code: {result.returncode}")
|
||||
@@ -1,76 +0,0 @@
|
||||
"""Test: run each approach in a separate process to survive LLVM ERROR."""
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import sys
|
||||
|
||||
approach = sys.argv[1] if len(sys.argv) > 1 else "arith"
|
||||
|
||||
if approach == "arith":
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import arith
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
return Int32(
|
||||
arith.fptosi(T.i32(), Float32(x).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
|
||||
)
|
||||
|
||||
elif approach == "nvvm_ptx":
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
elif approach == "nvvm_multiline":
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="{\n\tcvt.rni.s32.f32 $0, $1;\n\t}",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
else:
|
||||
print(f"Unknown approach: {approach}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == Int32(0):
|
||||
x = cute.arch.load(inp.iterator, Float32)
|
||||
r = f32_to_i32(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print(f"Approach: {approach}")
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(test_k, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()} (expected 4)")
|
||||
@@ -1,87 +0,0 @@
|
||||
"""Test: try different approaches to nvvm.inline_ptx wrapping."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
|
||||
# Approach 1: Return raw MLIR value, wrap at call site
|
||||
@dsl_user_op
|
||||
def f32_to_i32_raw(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
# nvvm.inline_ptx returns a Value; Int32() should wrap it
|
||||
return Int32(result)
|
||||
|
||||
|
||||
# Approach 2: Use nvvm.inline_ptx with two outputs (matching tutorial pattern)
|
||||
# Try with has_side_effects-like pattern
|
||||
@dsl_user_op
|
||||
def f32_to_i32_v2(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
# Use the exact same pattern as the tutorial's ptx_vote_ballot_sync
|
||||
return Int32(
|
||||
nvvm.inline_ptx(
|
||||
[T.i32()],
|
||||
[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
"cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_kernel_v1(
|
||||
input_f32: cute.Tensor,
|
||||
output_i32: cute.Tensor,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(input_f32.iterator, cutlass.Float32)
|
||||
result = f32_to_i32_raw(x)
|
||||
cute.arch.store(output_i32.iterator, result)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def test_kernel_v2(
|
||||
input_f32: cute.Tensor,
|
||||
output_i32: cute.Tensor,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(input_f32.iterator, cutlass.Float32)
|
||||
result = f32_to_i32_v2(x)
|
||||
cute.arch.store(output_i32.iterator, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
out = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)
|
||||
|
||||
print("=== Test V1 (raw result) ===")
|
||||
try:
|
||||
compiled = cute.compile(test_kernel_v1, xc, oc)
|
||||
compiled(xc, oc)
|
||||
print(f'V1: f32_to_i32(3.7) = {out.item()}')
|
||||
except Exception as e:
|
||||
print(f'V1 FAILED: {e}')
|
||||
|
||||
out.zero_()
|
||||
|
||||
print("\n=== Test V2 (list-style args) ===")
|
||||
try:
|
||||
compiled = cute.compile(test_kernel_v2, xc, oc)
|
||||
compiled(xc, oc)
|
||||
print(f'V2: f32_to_i32(3.7) = {out.item()}')
|
||||
except Exception as e:
|
||||
print(f'V2 FAILED: {e}')
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Minimal test: just the threshold rounding function in CuTeDSL."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from dsv4.kernels.gemm.fp4_quant import round_rne_u0_8
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def threshold_test_kernel(
|
||||
input_f32: cute.Tensor, # (1,) Float32 input
|
||||
output_i32: cute.Tensor, # (1,) Int32 output
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(input_f32.iterator, cutlass.Float32)
|
||||
result = round_rne_u0_8(x)
|
||||
cute.arch.store(output_i32.iterator, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
out = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(out).mark_layout_dynamic(leading_dim=0)
|
||||
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(threshold_test_kernel, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f'round(3.7) = {out.item()} (expected 4)')
|
||||
Reference in New Issue
Block a user