Files
nvfp4-megamoe-kernel/tests/unit/test_ptx_convert.py
biondizzle 1cbb3cf752 NVFP4-1.1: Replace threshold rounding with inline PTX cvt.rni/rz/rmi
- Add f32_to_i32_rni (cvt.rni.s32.f32) for round-to-nearest-even
- Add f32_to_i32_rz (cvt.rzi.s32.f32) for round-toward-zero
- Add f32_to_i32_rmi (cvt.rmi.s32.f32) for round-to-minus-infinity
- Replace round_rne_u0_8 and abs_scaled_to_e2m1_idx threshold hacks
  with proper PTX hardware rounding in fp8_e4m3_from_float32
- quantize_e2m1_nibble now uses f32_to_i32_rni + LUT logic for half_step
- Add test_ptx_convert.py for inline PTX conversion verification
- This is the CORRECT approach per NVFP4-1.1_INLINE_PTX_APPROACH.md option 1
2026-05-28 04:40:17 +00:00

107 lines
4.2 KiB
Python

"""Test: inline PTX f32→i32 conversion and FP4 quantization in CuTeDSL."""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
from dsv4.kernels.gemm.fp4_quant import (
f32_to_i32_rni,
f32_to_i32_rz,
f32_to_i32_rmi,
quantize_e2m1_nibble,
fp8_e4m3_from_float32,
)
@cute.kernel
def ptx_convert_test_kernel(
    input_f32: cute.Tensor,   # (N,) Float32 inputs
    output_rni: cute.Tensor,  # (N,) Int32 RNE results
    output_rz: cute.Tensor,   # (N,) Int32 RZ results
    output_rmi: cute.Tensor,  # (N,) Int32 RMI results
):
    """Convert each input element to Int32 with the three PTX rounding helpers.

    Every thread whose x-index is below the length of ``input_f32`` reads its
    own element and writes the result of ``f32_to_i32_rni``,
    ``f32_to_i32_rz`` and ``f32_to_i32_rmi`` to the same position of the
    corresponding output tensor. Threads past the end do nothing.
    """
    tidx, _, _ = cute.arch.thread_idx()
    if tidx < cutlass.Int32(cute.shape(input_f32)[0]):
        # Index by tidx: going through the bare base iterator would make every
        # thread read element 0 and overwrite element 0 of each output.
        x = input_f32[tidx]
        output_rni[tidx] = f32_to_i32_rni(x)
        output_rz[tidx] = f32_to_i32_rz(x)
        output_rmi[tidx] = f32_to_i32_rmi(x)
@cute.kernel
def fp4_quant_test_kernel(
    input_f32: cute.Tensor,      # (N,) Float32 values
    scale_f32: cute.Tensor,      # (1,) Float32 scale
    output_nibble: cute.Tensor,  # (N,) Int32 nibbles
):
    """Quantize each input element with ``quantize_e2m1_nibble``.

    Every thread whose x-index is below the length of ``input_f32`` reads its
    own element, reads the single shared scale, and stores the returned
    nibble at the same position of ``output_nibble``. Threads past the end do
    nothing.
    """
    tidx, _, _ = cute.arch.thread_idx()
    if tidx < cutlass.Int32(cute.shape(input_f32)[0]):
        # Per-thread element: addressing the base iterator directly would make
        # all threads quantize element 0 and write their result to slot 0.
        val = input_f32[tidx]
        # The scale tensor holds one value shared by all threads.
        scale = scale_f32[0]
        output_nibble[tidx] = quantize_e2m1_nibble(val, scale)
# E2M1 values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] → indices [0..7]
E2M1_VALUES = [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1_nibble(nibble: int):
    """Split a nibble into ``(sign, idx, dequant)``.

    Bit 3 is taken as the sign and bits 0-2 as the index into
    ``E2M1_VALUES``; any higher bits are ignored. ``dequant`` is the table
    value, negated when the sign bit is set.
    """
    sign = (nibble >> 3) & 1
    # Masking with 7 keeps idx in 0..7, so the lookup cannot go out of range.
    idx = nibble & 7
    dequant = E2M1_VALUES[idx]
    if sign:
        dequant = -dequant
    return sign, idx, dequant


if __name__ == "__main__":
    # Names of the checks that did not match; reported once at the very end so
    # the full diagnostic output is still printed on failure.
    failed_checks = []

    # Test PTX conversions
    test_vals = torch.tensor(
        [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, -1.7, 0.0, 2.0, 3.0],
        dtype=torch.float32, device='cuda'
    )
    N = test_vals.numel()
    out_rni = torch.zeros(N, dtype=torch.int32, device='cuda')
    out_rz = torch.zeros(N, dtype=torch.int32, device='cuda')
    out_rmi = torch.zeros(N, dtype=torch.int32, device='cuda')
    vc = cutlass_torch.from_dlpack(test_vals).mark_layout_dynamic(leading_dim=0)
    rni_c = cutlass_torch.from_dlpack(out_rni).mark_layout_dynamic(leading_dim=0)
    rz_c = cutlass_torch.from_dlpack(out_rz).mark_layout_dynamic(leading_dim=0)
    rmi_c = cutlass_torch.from_dlpack(out_rmi).mark_layout_dynamic(leading_dim=0)
    print("Compiling PTX conversion kernel...")
    # NOTE(review): cute.compile is handed the @cute.kernel function directly
    # and no grid/block configuration appears in this file — confirm that at
    # least N threads are actually launched.
    compiled = cute.compile(ptx_convert_test_kernel, vc, rni_c, rz_c, rmi_c)
    print("Running...")
    compiled(vc, rni_c, rz_c, rmi_c)

    # Copy each tensor to the host once instead of one .item() per element.
    host_vals = test_vals.tolist()
    got_rni = out_rni.tolist()
    got_rz = out_rz.tolist()
    got_rmi = out_rmi.tolist()
    print("\nPTX conversion results:")
    for v, rni, rz, rmi in zip(host_vals, got_rni, got_rz, got_rmi):
        print(f" val={v:6.1f} rni={rni:3d} rz={rz:3d} rmi={rmi:3d}")

    # Expected RNE: 0.5→0, 1.5→2, 2.5→2, 3.5→4, 4.5→4, 5.5→6, 6.5→6, 7.5→8, -1.7→-2, 0→0, 2→2, 3→3
    expected_rni = [0, 2, 2, 4, 4, 6, 6, 8, -2, 0, 2, 3]
    # Round toward zero: fractions are dropped, so -1.7→-1.
    expected_rz = [0, 1, 2, 3, 4, 5, 6, 7, -1, 0, 2, 3]
    # Round toward minus infinity (floor): -1.7→-2.
    expected_rmi = [0, 1, 2, 3, 4, 5, 6, 7, -2, 0, 2, 3]
    rni_ok = got_rni == expected_rni
    rz_ok = got_rz == expected_rz
    rmi_ok = got_rmi == expected_rmi
    print(f"\nRNE correct: {rni_ok}")
    print(f"RZ correct: {rz_ok}")
    print(f"RMI correct: {rmi_ok}")
    for label, ok in (("RNE", rni_ok), ("RZ", rz_ok), ("RMI", rmi_ok)):
        if not ok:
            failed_checks.append(label)

    # Test FP4 quantization
    print("\n--- FP4 Quantization Test ---")
    q_vals = torch.tensor([0.0, 0.25, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, -3.0],
                          dtype=torch.float32, device='cuda')
    scale = torch.tensor([1.0], dtype=torch.float32, device='cuda')
    q_out = torch.zeros(len(q_vals), dtype=torch.int32, device='cuda')
    qvc = cutlass_torch.from_dlpack(q_vals).mark_layout_dynamic(leading_dim=0)
    sc = cutlass_torch.from_dlpack(scale).mark_layout_dynamic(leading_dim=0)
    qoc = cutlass_torch.from_dlpack(q_out).mark_layout_dynamic(leading_dim=0)
    print("Compiling FP4 quant kernel...")
    compiled_q = cute.compile(fp4_quant_test_kernel, qvc, sc, qoc)
    print("Running...")
    compiled_q(qvc, sc, qoc)

    # sign bit: bit 3
    fp4_ok = True
    print("\nFP4 quantization results (scale=1.0):")
    for v, nib in zip(q_vals.tolist(), q_out.tolist()):
        sign, idx, dequant = decode_e2m1_nibble(nib)
        print(f" val={v:6.1f} nibble={nib:2d} sign={sign} idx={idx} dequant={dequant}")
        # Only inputs whose magnitude is itself a table entry are checked: with
        # scale=1.0 they must decode back to themselves. Values between table
        # entries (0.25 here) are skipped because the expected result depends
        # on tie handling that this script does not define.
        if abs(v) in E2M1_VALUES and dequant != v:
            fp4_ok = False
    print(f"\nFP4 exact values round-trip: {fp4_ok}")
    if not fp4_ok:
        failed_checks.append("FP4")

    if failed_checks:
        # Exit non-zero so a mismatch is visible to whatever runs this script.
        raise SystemExit("FAILED checks: " + ", ".join(failed_checks))
    print("\nDone!")