diff --git a/tests/unit/test_fp4_isolate.py b/tests/unit/test_fp4_isolate.py new file mode 100644 index 00000000..ea4134a4 --- /dev/null +++ b/tests/unit/test_fp4_isolate.py @@ -0,0 +1,98 @@ +"""Isolate which function causes the LLVM ERROR.""" +import torch +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +import sys + +from dsv4.kernels.gemm.fp4_quant import ( + fp8_e4m3_from_float32, + fp8_e4m3_to_float32, + quantize_e2m1_nibble, + round_rne_u0_8, + abs_scaled_to_e2m1_idx, +) + +test = sys.argv[1] if len(sys.argv) > 1 else "round_rne" + +if test == "round_rne": + @cute.kernel + def k(inp: cute.Tensor, out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + x = cute.arch.load(inp.iterator, cutlass.Float32) + r = round_rne_u0_8(x) + cute.arch.store(out.iterator, r) + +elif test == "abs_scaled": + @cute.kernel + def k(inp: cute.Tensor, out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + x = cute.arch.load(inp.iterator, cutlass.Float32) + r = abs_scaled_to_e2m1_idx(x) + cute.arch.store(out.iterator, r) + +elif test == "fp8_encode": + @cute.kernel + def k(inp: cute.Tensor, out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + x = cute.arch.load(inp.iterator, cutlass.Float32) + r = fp8_e4m3_from_float32(x) + cute.arch.store(out.iterator, r) + +elif test == "fp8_decode": + @cute.kernel + def k(inp: cute.Tensor, out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + x = cute.arch.load(inp.iterator, cutlass.Int32) + r = fp8_e4m3_to_float32(x) + cute.arch.store(out.iterator, r) + +elif test == "e2m1_quant": + @cute.kernel + def k(val: cute.Tensor, scale: cute.Tensor, out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx == cutlass.Int32(0): + v = cute.arch.load(val.iterator, cutlass.Float32) + s = cute.arch.load(scale.iterator, cutlass.Float32) + r = quantize_e2m1_nibble(v, s) + cute.arch.store(out.iterator, r) + + +if __name__ == "__main__": + print(f"Test: {test}") + if test == "e2m1_quant": + v = torch.tensor([1.5], dtype=torch.float32, device='cuda') + s = torch.tensor([1.0], dtype=torch.float32, device='cuda') + o = torch.zeros(1, dtype=torch.int32, device='cuda') + vc = cutlass_torch.from_dlpack(v).mark_layout_dynamic(leading_dim=0) + sc = cutlass_torch.from_dlpack(s).mark_layout_dynamic(leading_dim=0) + oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0) + print("Compiling...") + compiled = cute.compile(k, vc, sc, oc) + print("Running...") + compiled(vc, sc, oc) + print(f"Result: {o.item()}") + elif test == "fp8_decode": + x = torch.tensor([126], dtype=torch.int32, device='cuda') + o = torch.zeros(1, dtype=torch.float32, device='cuda') + xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0) + oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0) + print("Compiling...") + compiled = cute.compile(k, xc, oc) + print("Running...") + compiled(xc, oc) + print(f"Result: {o.item()}") + else: + x = torch.tensor([3.7], dtype=torch.float32, device='cuda') + o = torch.zeros(1, dtype=torch.int32, device='cuda') + xc = cutlass_torch.from_dlpack(x).mark_layout_dynamic(leading_dim=0) + oc = cutlass_torch.from_dlpack(o).mark_layout_dynamic(leading_dim=0) + print("Compiling...") + compiled = cute.compile(k, xc, oc) + print("Running...") + compiled(xc, oc) + print(f"Result: {o.item()}") diff --git a/tests/unit/test_fp4_isolate_runner.py b/tests/unit/test_fp4_isolate_runner.py new file mode 100644 index 00000000..1aa3c1e1 --- /dev/null +++ b/tests/unit/test_fp4_isolate_runner.py @@ -0,0 +1,10 @@ +"""Run each function isolation test in a separate process.""" +import subprocess, sys +for t in ["round_rne", "abs_scaled", "fp8_encode", "fp8_decode", "e2m1_quant"]: + print(f"\n{'='*50}\n{t}\n{'='*50}") + r = subprocess.run([sys.executable, "tests/unit/test_fp4_isolate.py", t], + capture_output=True, text=True, timeout=120) + print(r.stdout[-300:] if r.stdout else "") + if r.stderr: + print(f"ERR: ...{r.stderr[-200:]}") + print(f"Exit: {r.returncode}")