Files
nvfp4-megamoe-kernel/tests/unit/test_fp4_isolate.py

99 lines
3.6 KiB
Python

"""Isolate which function causes the LLVM ERROR."""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import sys
from dsv4.kernels.gemm.fp4_quant import (
fp8_e4m3_from_float32,
fp8_e4m3_to_float32,
quantize_e2m1_nibble,
round_rne_u0_8,
abs_scaled_to_e2m1_idx,
)
# Name of the helper to exercise, read from the first command-line argument;
# defaults to "round_rne" when no argument is supplied.
test = sys.argv[1] if len(sys.argv) > 1 else "round_rne"

# Exactly one kernel named `k` is defined, selected by `test`, so that a run
# touches only one of the helpers imported from fp4_quant. If `test` matches
# none of the names below, `k` is left undefined.
# Every variant does its work on thread 0 only (tidx == 0).
if test == "round_rne":
    @cute.kernel
    def k(inp: cute.Tensor, out: cute.Tensor):
        # Load one value as Float32 through inp.iterator, pass it to
        # round_rne_u0_8, and store the result through out.iterator.
        tidx, _, _ = cute.arch.thread_idx()
        if tidx == cutlass.Int32(0):
            x = cute.arch.load(inp.iterator, cutlass.Float32)
            r = round_rne_u0_8(x)
            cute.arch.store(out.iterator, r)
elif test == "abs_scaled":
    @cute.kernel
    def k(inp: cute.Tensor, out: cute.Tensor):
        # Load one value as Float32, pass it to abs_scaled_to_e2m1_idx, and
        # store the result through out.iterator.
        tidx, _, _ = cute.arch.thread_idx()
        if tidx == cutlass.Int32(0):
            x = cute.arch.load(inp.iterator, cutlass.Float32)
            r = abs_scaled_to_e2m1_idx(x)
            cute.arch.store(out.iterator, r)
elif test == "fp8_encode":
    @cute.kernel
    def k(inp: cute.Tensor, out: cute.Tensor):
        # Load one value as Float32, pass it to fp8_e4m3_from_float32, and
        # store the result through out.iterator.
        tidx, _, _ = cute.arch.thread_idx()
        if tidx == cutlass.Int32(0):
            x = cute.arch.load(inp.iterator, cutlass.Float32)
            r = fp8_e4m3_from_float32(x)
            cute.arch.store(out.iterator, r)
elif test == "fp8_decode":
    @cute.kernel
    def k(inp: cute.Tensor, out: cute.Tensor):
        # Unlike the other single-input variants, the input is loaded as
        # Int32 before being passed to fp8_e4m3_to_float32.
        tidx, _, _ = cute.arch.thread_idx()
        if tidx == cutlass.Int32(0):
            x = cute.arch.load(inp.iterator, cutlass.Int32)
            r = fp8_e4m3_to_float32(x)
            cute.arch.store(out.iterator, r)
elif test == "e2m1_quant":
    @cute.kernel
    def k(val: cute.Tensor, scale: cute.Tensor, out: cute.Tensor):
        # Two-input variant: load `val` and `scale` as Float32, pass both to
        # quantize_e2m1_nibble, and store the result through out.iterator.
        tidx, _, _ = cute.arch.thread_idx()
        if tidx == cutlass.Int32(0):
            v = cute.arch.load(val.iterator, cutlass.Float32)
            s = cute.arch.load(scale.iterator, cutlass.Float32)
            r = quantize_e2m1_nibble(v, s)
            cute.arch.store(out.iterator, r)
if __name__ == "__main__":
    # Driver: build one-element CUDA tensors for the selected test, compile
    # the kernel `k` defined above for that test, run it once, and print the
    # single output value.
    print(f"Test: {test}")

    # `k` only exists for these names. Reject anything else up front with a
    # readable message instead of failing later with a NameError on `k`
    # after CUDA tensors have already been allocated.
    known_tests = (
        "round_rne",
        "abs_scaled",
        "fp8_encode",
        "fp8_decode",
        "e2m1_quant",
    )
    if test not in known_tests:
        raise SystemExit(
            f"Unknown test {test!r}; expected one of: {', '.join(known_tests)}"
        )

    # Per-test inputs and output dtype. Only these differ between tests; the
    # wrap / compile / run / print sequence below is shared.
    if test == "e2m1_quant":
        # Two float32 inputs (value, scale), int32 output.
        inputs = [
            torch.tensor([1.5], dtype=torch.float32, device='cuda'),
            torch.tensor([1.0], dtype=torch.float32, device='cuda'),
        ]
        out_dtype = torch.int32
    elif test == "fp8_decode":
        # One int32 input, float32 output.
        inputs = [torch.tensor([126], dtype=torch.int32, device='cuda')]
        out_dtype = torch.float32
    else:
        # round_rne, abs_scaled and fp8_encode share one float32 input and an
        # int32 output buffer.
        inputs = [torch.tensor([3.7], dtype=torch.float32, device='cuda')]
        out_dtype = torch.int32

    o = torch.zeros(1, dtype=out_dtype, device='cuda')

    # Wrap every tensor for the DSL in kernel-argument order: inputs first,
    # output last.
    kernel_args = [
        cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=0)
        for t in (*inputs, o)
    ]

    # Progress markers are printed before each stage so the last line shown
    # identifies where a crash happened.
    print("Compiling...")
    compiled = cute.compile(k, *kernel_args)
    print("Running...")
    compiled(*kernel_args)
    print(f"Result: {o.item()}")