test: isolate which fp4_quant function causes LLVM ERROR
This commit is contained in:
98
tests/unit/test_fp4_isolate.py
Normal file
98
tests/unit/test_fp4_isolate.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Isolate which function causes the LLVM ERROR."""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import sys
|
||||
|
||||
from dsv4.kernels.gemm.fp4_quant import (
|
||||
fp8_e4m3_from_float32,
|
||||
fp8_e4m3_to_float32,
|
||||
quantize_e2m1_nibble,
|
||||
round_rne_u0_8,
|
||||
abs_scaled_to_e2m1_idx,
|
||||
)
|
||||
|
||||
test = sys.argv[1] if len(sys.argv) > 1 else "round_rne"
|
||||
|
||||
if test == "round_rne":
|
||||
@cute.kernel
|
||||
def k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(inp.iterator, cutlass.Float32)
|
||||
r = round_rne_u0_8(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
elif test == "abs_scaled":
|
||||
@cute.kernel
|
||||
def k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(inp.iterator, cutlass.Float32)
|
||||
r = abs_scaled_to_e2m1_idx(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
elif test == "fp8_encode":
|
||||
@cute.kernel
|
||||
def k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(inp.iterator, cutlass.Float32)
|
||||
r = fp8_e4m3_from_float32(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
elif test == "fp8_decode":
|
||||
@cute.kernel
|
||||
def k(inp: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
x = cute.arch.load(inp.iterator, cutlass.Int32)
|
||||
r = fp8_e4m3_to_float32(x)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
elif test == "e2m1_quant":
|
||||
@cute.kernel
|
||||
def k(val: cute.Tensor, scale: cute.Tensor, out: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if tidx == cutlass.Int32(0):
|
||||
v = cute.arch.load(val.iterator, cutlass.Float32)
|
||||
s = cute.arch.load(scale.iterator, cutlass.Float32)
|
||||
r = quantize_e2m1_nibble(v, s)
|
||||
cute.arch.store(out.iterator, r)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Test: {test}")
|
||||
if test == "e2m1_quant":
|
||||
v = torch.tensor([1.5], dtype=torch.float32, device='cuda')
|
||||
s = torch.tensor([1.0], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
vc = cutlass_torch.from_dlpack(v).mark_layout_dynamic(leading_dim=0)
|
||||
sc = cutlass_torch.from_dlpack(s).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(k, vc, sc, oc)
|
||||
print("Running...")
|
||||
compiled(vc, sc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
elif test == "fp8_decode":
|
||||
x = torch.tensor([126], dtype=torch.int32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.float32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(k, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
else:
|
||||
x = torch.tensor([3.7], dtype=torch.float32, device='cuda')
|
||||
o = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0)
|
||||
oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0)
|
||||
print("Compiling...")
|
||||
compiled = cute.compile(k, xc, oc)
|
||||
print("Running...")
|
||||
compiled(xc, oc)
|
||||
print(f"Result: {o.item()}")
|
||||
10
tests/unit/test_fp4_isolate_runner.py
Normal file
10
tests/unit/test_fp4_isolate_runner.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Run each function isolation test in a separate process."""
|
||||
import subprocess, sys
|
||||
for t in ["round_rne", "abs_scaled", "fp8_encode", "fp8_decode", "e2m1_quant"]:
|
||||
print(f"\n{'='*50}\n{t}\n{'='*50}")
|
||||
r = subprocess.run([sys.executable, "tests/unit/test_fp4_isolate.py", t],
|
||||
capture_output=True, text=True, timeout=120)
|
||||
print(r.stdout[-300:] if r.stdout else "")
|
||||
if r.stderr:
|
||||
print(f"ERR: ...{r.stderr[-200:]}")
|
||||
print(f"Exit: {r.returncode}")
|
||||
Reference in New Issue
Block a user