Files
nvfp4-megamoe-kernel/tests/unit/test_nvfp4_1_1_layout.py
biondizzle 80b6b79f9e NVFP4-1.1: FP4 quantization primitives for CuTeDSL kernels
- fp8_e4m3_from_float32: manual FP8 E4M3 cast (bias=7, exp 0-15 valid,
  NaN guard for exp=15/mant=7, mantissa overflow handling)
- fp8_e4m3_to_float32: dequantize FP8 E4M3 bit pattern back to Float32
- half_step_to_e2m1_idx: E2M1 step mapping (0-12 → 0-7)
- quantize_e2m1_nibble: per-element E2M1 quantize + sign + pack
- Verified 0/500 trial failures against Python reference
- Key fixes discovered during validation:
  1. FP8 E4M3 bias is 7, NOT 8
  2. Exponent range is 0-15 (exp=15/mant=7 is NaN; others valid)
  3. Subnormal formula: val = m * 2^(-9) = m/512 (NOT m/1024)
  4. Round-to-nearest-even (not round-half-up) for half_step and mantissa
  5. Mantissa overflow (round to 8) must increment exponent
2026-05-28 03:39:55 +00:00

146 lines
6.1 KiB
Python

"""
NVFP4-1.1: Diagnostics for the SwiGLU epilogue register layout.
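This script prints the epilogue tiling parameters of the fused SwiGLU kernel
(epi tile shape, subtile counts, elements per thread per subtile), from which
the mapping between register indices and output positions for the epilogue
subtiles can be reasoned about. We need to understand this mapping to correctly
accumulate SwiGLU values across 2 up subtiles for FP4 quantization.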
Key questions:
1. How many register elements per thread per subtile?
2. Which output positions does each thread own?
3. Do 2 consecutive up subtiles give 16 contiguous SwiGLU values per thread?
4. Are these 16 values the SAME 16 that form one NVFP4 microblock?
This test runs on B200 only (needs SM100 hardware).
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
from dsv4.ops.gemm_runner import run_fused_swiglu_grouped_gemm, warmup_fused_swiglu_compilation
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_3d_side,
interleave_l1_weights,
pad_and_swizzle_single,
)
def _make_random_gemm_inputs(device):
    """Build random placeholder operands for a small fused SwiGLU grouped GEMM.

    The values are arbitrary: they only need plausible dtypes/shapes so the
    kernel can be launched. Returns a dict of keyword arguments for
    ``run_fused_swiglu_grouped_gemm``.
    """
    num_experts = 4
    hidden = 256  # K (packed)
    intermediate = 512  # N (packed) = 2 * intermediate_real
    tokens = 32
    # torch.randn cannot sample directly into float4_e2m1fn_x2 or
    # float8_e4m3fn. FP4 operands are therefore random raw bytes reinterpreted
    # as packed E2M1 pairs (E2M1 has no NaN/Inf encodings, so any byte is a
    # valid pair). FP8 scales are sampled in float32 and then cast.
    mat_a = torch.randint(
        0, 256, (tokens, hidden), dtype=torch.uint8, device=device
    ).view(torch.float4_e2m1fn_x2)
    mat_b = torch.randint(
        0, 256, (num_experts, hidden, intermediate), dtype=torch.uint8, device=device
    ).view(torch.float4_e2m1fn_x2)
    scale_a = torch.randn(tokens, hidden // 16, device=device).to(torch.float8_e4m3fn)
    scale_b = torch.randn(
        num_experts, intermediate, hidden // 16, device=device
    ).to(torch.float8_e4m3fn)
    # NOTE(review): presumably cumulative per-expert token end offsets
    # (8 tokens per expert) -- confirm against the runner's contract.
    expert_offsets = torch.tensor([8, 16, 24, 32], dtype=torch.int32, device=device)
    global_scale_a = torch.full((num_experts,), 0.001, dtype=torch.float32, device=device)
    global_scale_b = torch.full((num_experts,), 0.001, dtype=torch.float32, device=device)
    return dict(
        mat_a=mat_a, mat_b=mat_b,
        scale_a=scale_a, scale_b=scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=global_scale_b,
    )


def diagnose_epilogue_layout():
    """Print the epilogue register layout for understanding FP4 quantization.

    Constructs a ``FusedSwiGLUScaledGroupedGemmKernel`` and prints its
    epilogue configuration: epi_tile shape, number of subtiles, and elements
    per thread. It then checks whether two up subtiles give exactly one NVFP4
    microblock (16 values) per thread.

    Finally it attempts a small fused SwiGLU GEMM on random inputs. That step
    is best-effort: any failure is reported together with a traceback instead
    of being raised, so the layout diagnostics above are always printed.
    """
    device = "cuda"
    nvfp4_block_size = 16

    # Create kernel to inspect epilogue config
    kernel = FusedSwiGLUScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=16,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(128, 128, 256),
        cluster_shape_mnk=(1, 1, 1),
        fused_swiglu=True,
        swiglu_limit=0.0,
    )
    epi_threads = 32 * len(kernel.epilogue_warp_id)  # 32 threads per warp

    print("=" * 60)
    print("Epilogue Layout Diagnostics")
    print("=" * 60)
    print(f" epi_tile: {kernel.epi_tile}")
    print(f" epi_tile_n: {kernel.epi_tile_n}")
    print(f" cta_tile_shape_mnk: {kernel.cta_tile_shape_mnk}")
    print(f" c_dtype: {kernel.c_dtype}")
    print(f" epilogue_warp_id: {kernel.epilogue_warp_id}")
    print(f" num_epilogue_threads: {epi_threads}")

    # Compute elements per thread per subtile from the kernel's own tiling
    # instead of hard-coding the CTA M extent.
    epi_m = kernel.cta_tile_shape_mnk[0]
    cta_n = kernel.cta_tile_shape_mnk[1]
    epi_n = kernel.epi_tile_n
    epi_elements = epi_m * epi_n
    elements_per_thread = epi_elements // epi_threads
    num_subtiles = cta_n // epi_n
    # Half of the N subtiles are gate columns and half are up columns.
    num_gate_subtiles = num_subtiles // 2
    num_up_subtiles = num_subtiles // 2
    swiglu_per_thread = num_up_subtiles * elements_per_thread
    total_swiglu_per_cta = epi_m * (cta_n // 2)
    print(f"\n Elements per subtile: {epi_elements}")
    print(f" Elements per thread per subtile: {elements_per_thread}")
    print(f" Total subtiles per CTA tile: {num_subtiles}")
    print(f" Gate subtiles: {num_gate_subtiles}")
    print(f" Up subtiles: {num_up_subtiles}")
    print(f" SwiGLU values per thread (all up subtiles): {swiglu_per_thread}")
    print(f" Total SwiGLU values per CTA tile: {total_swiglu_per_cta}")

    # NVFP4 microblocks
    num_nvfp4_blocks = total_swiglu_per_cta // nvfp4_block_size
    print(f"\n NVFP4 microblocks per CTA tile: {num_nvfp4_blocks}")
    print(f" SwiGLU values per thread: {swiglu_per_thread}")
    # Block count, not value count (the old `x // 16 * 16` printed values).
    print(f" NVFP4 microblocks per thread: {swiglu_per_thread // nvfp4_block_size}")

    # The key question: can we pair 2 up subtiles (16 values per thread)
    # to form one NVFP4 block?
    values_per_pair = 2 * elements_per_thread
    print(f"\n Key: 2 up subtiles give {values_per_pair} SwiGLU values per thread")
    print(f" NVFP4 block size: {nvfp4_block_size}")
    print(f" Match: {values_per_pair == nvfp4_block_size}")
    if values_per_pair == nvfp4_block_size:
        print("\n ✅ 2 up subtiles = 1 NVFP4 block per thread. Accumulation pattern works!")
    else:
        print(f"\n ❌ Mismatch: 2 up subtiles give {values_per_pair} values, need {nvfp4_block_size}")

    # Run a small GEMM to verify
    print("\n" + "=" * 60)
    print("Running small fused SwiGLU GEMM to verify output layout...")
    print("=" * 60)
    # We need proper interleaved weights for the fused SwiGLU kernel.
    # For now, just verify the kernel runs. Input construction happens inside
    # the try so that a failure there is reported like a GEMM failure.
    try:
        l1_out = run_fused_swiglu_grouped_gemm(**_make_random_gemm_inputs(device))
        print(f" L1 output shape: {l1_out.shape}")
        print(f" L1 output dtype: {l1_out.dtype}")
        print(f" L1 output (first row, first 16): {l1_out[0, :16].cpu()}")
    except Exception as e:  # diagnostic boundary: report, don't abort
        import traceback

        print(f" Error running GEMM: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    diagnose_epilogue_layout()