- fp8_e4m3_from_float32: manual FP8 E4M3 cast (bias=7, exp 0-15 valid, NaN guard for exp=15/mant=7, mantissa overflow handling)
- fp8_e4m3_to_float32: dequantize FP8 E4M3 bit pattern back to Float32
- half_step_to_e2m1_idx: E2M1 step mapping (0-12 → 0-7)
- quantize_e2m1_nibble: per-element E2M1 quantize + sign + pack
- Verified 0/500 trial failures against Python reference
- Key fixes discovered during validation:
  1. FP8 E4M3 bias is 7, NOT 8
  2. Exponent range is 0-15 (exp=15/mant=7 is NaN; others valid)
  3. Subnormal formula: val = m * 2^(-9) = m/512 (NOT m/1024)
  4. Round-to-nearest-even (not round-half-up) for half_step and mantissa
  5. Mantissa overflow (round to 8) must increment exponent
146 lines
6.1 KiB
Python
146 lines
6.1 KiB
Python
"""
NVFP4-1.1: Diagnostics for the SwiGLU epilogue register layout.

This script prints the mapping between register indices and output positions
for the epilogue subtiles. We need to understand this mapping to correctly
accumulate SwiGLU values across 2 up subtiles for FP4 quantization.

Key questions:
1. How many register elements per thread per subtile?
2. Which output positions does each thread own?
3. Do 2 consecutive up subtiles give 16 contiguous SwiGLU values per thread?
4. Are these 16 values the SAME 16 that form one NVFP4 microblock?

This test runs on B200 only (needs SM100 hardware).
"""
|
|
|
|
import torch
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as cutlass_torch
|
|
import sys
|
|
import os
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
|
|
|
|
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
|
|
from dsv4.ops.gemm_runner import run_fused_swiglu_grouped_gemm, warmup_fused_swiglu_compilation
|
|
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
|
|
from dsv4.ops.layouts import (
|
|
make_b_k_major,
|
|
assemble_scales_3d_side,
|
|
interleave_l1_weights,
|
|
pad_and_swizzle_single,
|
|
)
|
|
|
|
|
|
def diagnose_epilogue_layout():
    """Print the epilogue register layout for understanding FP4 quantization.

    Builds a ``FusedSwiGLUScaledGroupedGemmKernel`` and prints its epilogue
    configuration (epi_tile shape, number of subtiles, elements per thread),
    then reports whether two up subtiles supply exactly one NVFP4 microblock
    (16 values) per thread. Finally it attempts a small fused SwiGLU grouped
    GEMM as a smoke test.

    Requires a CUDA device. The smoke-test GEMM is best-effort: any exception
    it raises is printed rather than propagated, so the layout diagnostics
    above it are always shown.
    """
    device = "cuda"
    num_experts = 4
    hidden = 256  # K (packed)
    intermediate = 512  # N (packed) = 2 * intermediate_real
    tokens = 32

    def _random_fp4(*shape):
        # torch.randn cannot sample into the packed FP4 dtype, so draw random
        # bytes and reinterpret them; every E2M1 nibble encodes a finite value.
        raw = torch.randint(0, 256, shape, dtype=torch.uint8, device=device)
        return raw.view(torch.float4_e2m1fn_x2)

    def _random_fp8(*shape):
        # torch.randn is likewise unavailable for FP8; sample in float32 and
        # cast down instead.
        # NOTE(review): scales drawn this way can be negative — confirm the
        # kernel tolerates that for a smoke test.
        sample = torch.randn(*shape, dtype=torch.float32, device=device)
        return sample.to(torch.float8_e4m3fn)

    # Create test inputs
    mat_a = _random_fp4(tokens, hidden)
    mat_b = _random_fp4(num_experts, hidden, intermediate)
    scale_a = _random_fp8(tokens, hidden // 16)
    scale_b = _random_fp8(num_experts, intermediate, hidden // 16)
    expert_offsets = torch.tensor([8, 16, 24, 32], dtype=torch.int32, device=device)
    global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device) * 0.001
    global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device) * 0.001

    # Create kernel to inspect epilogue config
    kernel = FusedSwiGLUScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=16,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(128, 128, 256),
        cluster_shape_mnk=(1, 1, 1),
        fused_swiglu=True,
        swiglu_limit=0.0,
    )

    epi_threads = 32 * len(kernel.epilogue_warp_id)  # 32 threads per warp

    print("=" * 60)
    print("Epilogue Layout Diagnostics")
    print("=" * 60)
    print(f" epi_tile: {kernel.epi_tile}")
    print(f" epi_tile_n: {kernel.epi_tile_n}")
    print(f" cta_tile_shape_mnk: {kernel.cta_tile_shape_mnk}")
    print(f" c_dtype: {kernel.c_dtype}")
    print(f" epilogue_warp_id: {kernel.epilogue_warp_id}")
    print(f" num_epilogue_threads: {epi_threads}")

    # Compute elements per thread per subtile. The trailing numbers are the
    # values expected for the (128, 128, 256) tiler configured above.
    cta_m = kernel.cta_tile_shape_mnk[0]  # 128
    cta_n = kernel.cta_tile_shape_mnk[1]  # 128
    epi_m = cta_m
    epi_n = kernel.epi_tile_n  # expected 8 for fused_swiglu
    epi_elements = epi_m * epi_n  # 128 * 8 = 1024 elements per subtile
    elements_per_thread = epi_elements // epi_threads  # 1024 / 128 = 8
    num_subtiles = cta_n // epi_n  # 128 / 8 = 16
    num_gate_subtiles = num_subtiles // 2  # 8
    num_up_subtiles = num_subtiles // 2  # 8
    swiglu_per_thread = num_up_subtiles * elements_per_thread  # 8 * 8 = 64
    total_swiglu_per_cta = epi_m * (cta_n // 2)  # 128 * 64 = 8192

    print(f"\n Elements per subtile: {epi_elements}")
    print(f" Elements per thread per subtile: {elements_per_thread}")
    print(f" Total subtiles per CTA tile: {num_subtiles}")
    print(f" Gate subtiles: {num_gate_subtiles}")
    print(f" Up subtiles: {num_up_subtiles}")
    print(f" SwiGLU values per thread (all up subtiles): {swiglu_per_thread}")
    print(f" Total SwiGLU values per CTA tile: {total_swiglu_per_cta}")

    # NVFP4 microblocks
    nvfp4_block_size = 16
    num_nvfp4_blocks = total_swiglu_per_cta // nvfp4_block_size  # 8192 / 16 = 512

    print(f"\n NVFP4 microblocks per CTA tile: {num_nvfp4_blocks}")
    print(f" SwiGLU values per thread: {swiglu_per_thread}")
    # Block count is values // block size (64 / 16 = 4); multiplying back by
    # the block size would just reprint the value count.
    print(f" NVFP4 microblocks per thread: {swiglu_per_thread // nvfp4_block_size}")

    # The key question: can we pair 2 up subtiles (16 values per thread)
    # to form one NVFP4 block?
    paired_values = 2 * elements_per_thread
    pair_matches_block = paired_values == nvfp4_block_size

    print(f"\n Key: 2 up subtiles give {paired_values} SwiGLU values per thread")
    print(f" NVFP4 block size: {nvfp4_block_size}")
    print(f" Match: {pair_matches_block}")

    if pair_matches_block:
        print("\n ✅ 2 up subtiles = 1 NVFP4 block per thread. Accumulation pattern works!")
    else:
        print(f"\n ❌ Mismatch: 2 up subtiles give {paired_values} values, need {nvfp4_block_size}")

    # Run a small GEMM to verify
    print("\n" + "=" * 60)
    print("Running small fused SwiGLU GEMM to verify output layout...")
    print("=" * 60)

    # We need proper interleaved weights for the fused SwiGLU kernel
    # For now, just verify the kernel runs
    try:
        l1_out = run_fused_swiglu_grouped_gemm(
            mat_a=mat_a, mat_b=mat_b,
            scale_a=scale_a, scale_b=scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=global_scale_a, global_scale_b=global_scale_b,
        )
        print(f" L1 output shape: {l1_out.shape}")
        print(f" L1 output dtype: {l1_out.dtype}")
        print(f" L1 output (first row, first 16): {l1_out[0, :16].cpu()}")
    except Exception as e:
        # Deliberately broad: this is a best-effort smoke test and a failure
        # here must not hide the layout diagnostics printed above.
        print(f" Error running GEMM: {e}")
|
|
|
|
|
# Script entry point: run the diagnostics only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    diagnose_epilogue_layout()
|