Files
nvfp4-megamoe-kernel/tests/unit/test_b2_debug_qk.py

118 lines
4.6 KiB
Python

#!/usr/bin/env python3
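"""B2 debug: compare FP8 QK logits with FP32 reference QK scores.

Simulates the FP8 GEMM and its dequantization in PyTorch to check that the
dequantized logits match the reference scores, then runs the B2 kernel with
unit head weights (so the weighted ReLU reduces to a plain ReLU sum over
heads) and compares its top-k indices with the reference top-k.
Raw TMEM logits are not read back by this script.
"""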
import sys
import math
import torch
import torch.nn.functional as F
def quantize_fp8_e4m3(x_fp32):
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
scale = amax / 448.0
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
return fp8.view(torch.uint8), scale.squeeze(-1)
def dequantize_fp8_e4m3(fp8_uint8, scale):
fp8 = fp8_uint8.view(torch.float8_e4m3fn)
return fp8.float() * scale.unsqueeze(-1).float()
def main():
    """Compare simulated FP8 QK logits and the kernel's top-k with a reference.

    Steps, in order:
      1. Build random Q (bfloat16, on the GPU) and K (FP8-quantized) inputs.
      2. Compute reference scores from Q and the dequantized K in float32.
      3. Quantize Q per row, simulate the FP8 GEMM plus dequantization in
         PyTorch, and print the cosine similarity against the reference.
      4. Load and run the ``indexer_fp8_score_topk`` CUDA kernel and print the
         overlap between its top-k indices and the reference top-k.

    Needs a CUDA device (tensors are moved with ``.cuda()``) and the project's
    ``dsv4`` package. Always terminates the process with exit status 0,
    whatever the overlap turns out to be.
    """

    def _flat_cosine(lhs, rhs):
        # Cosine similarity of two tensors viewed as single flat vectors.
        return F.cosine_similarity(
            lhs.flatten().unsqueeze(0), rhs.flatten().unsqueeze(0)
        ).item()

    torch.manual_seed(42)

    # Problem sizes, kept small for manual inspection.
    N_IH = 64
    IHD = 128
    N_COMP = 8
    TOP_K = 8

    q_idx = torch.randn(N_IH, IHD, dtype=torch.bfloat16).cuda() * 0.5
    k_fp32 = torch.randn(N_COMP, IHD, dtype=torch.float32) * 0.5
    # All-positive head weights for simplicity.
    w_h = torch.ones(N_IH, dtype=torch.bfloat16).cuda()

    k_codes_host, k_scale_host = quantize_fp8_e4m3(k_fp32)
    k_fp8 = k_codes_host.cuda()
    k_scale = k_scale_host.cuda()

    # --- FP32 reference scores (full matrix) ---
    k_bytes_cpu = k_fp8.view(torch.uint8).cpu()
    k_dequant = dequantize_fp8_e4m3(k_bytes_cpu, k_scale.cpu()).cuda()
    # Shape (N_IH, N_COMP) = (64, 8).
    scores_ref = torch.einsum('nd,cd->nc', q_idx.float(), k_dequant.float())
    print("=== FP32 Reference Scores (h, c) ===")
    for h in range(4):
        print(f" H{h}: {scores_ref[h, :4].tolist()}")

    # --- Per-row quantization of Q (intended to mirror the kernel) ---
    q_full = q_idx.float()
    q_peak = q_full.abs().amax(dim=-1)
    q_scales = (q_peak / 448.0).clamp(min=1e-8)
    q_scaled = (q_full / q_scales.unsqueeze(-1)).clamp(-448, 448)
    q_codes = q_scaled.to(torch.float8_e4m3fn).view(torch.uint8)

    # Round-trip Q through FP8 and measure how much is lost.
    q_dequant = dequantize_fp8_e4m3(q_codes, q_scales)
    cos_q = _flat_cosine(q_full, q_dequant)
    print(f"\nQ FP8 round-trip cosine: {cos_q:.6f}")

    # Simulate the FP8 GEMM in PyTorch. With q_fp8 = q / q_scale and
    # k_fp8 = k / k_scale, the MMA yields raw logits
    #     raw = q_fp8 @ k_fp8^T = (q / q_scale) @ (k / k_scale)^T
    # and dequantizing them recovers the true product:
    #     raw * q_scale[h] * k_scale[c] = q @ k^T
    # Dividing the dequantized tensors by their scales gives back the values
    # the FP8 codes represent, held in float32.
    q_mma_operand = q_dequant / q_scales.unsqueeze(-1)
    k_mma_operand = k_dequant / k_scale.unsqueeze(-1)
    raw_logits = torch.einsum('nd,cd->nc', q_mma_operand, k_mma_operand)
    dequant_logits = raw_logits * q_scales.unsqueeze(-1) * k_scale.unsqueeze(0)

    print(f"\nSimulated dequant logits:")
    for h in range(4):
        print(f" H{h}: {dequant_logits[h, :4].tolist()}")
        print(f" Ref: {scores_ref[h, :4].tolist()}")
    cos_logits = _flat_cosine(dequant_logits, scores_ref)
    print(f" Cosine vs ref: {cos_logits:.6f}")

    # The kernel's MMA should produce similar raw logits; if the kernel run
    # below disagrees with the reference, the MMA itself is suspect.
    # --- Run B2 kernel ---
    from dsv4.kernels.cuda.loader import get_cuda_module

    nvcc_flags = [
        "-gencode=arch=compute_100a,code=sm_100a",
        "-O3",
        "--use_fast_math",
        "--expt-relaxed-constexpr",
    ]
    mod = get_cuda_module(
        "indexer_fp8_score_topk",
        ["indexer_fp8_score_topk.cu"],
        extra_cuda_cflags=nvcc_flags,
    )

    topk_indices = torch.empty(TOP_K, dtype=torch.int32, device='cuda')
    mod.indexer_fp8_score_topk(
        q_idx, k_fp8, k_scale, w_h, topk_indices, N_IH, IHD, TOP_K
    )
    torch.cuda.synchronize()
    print(f"\nKernel top-k (w_h=1): {topk_indices.tolist()}")

    # Reference top-k with w_h=1: ReLU each score, then sum over heads.
    ref_total = F.relu(scores_ref).sum(0)
    ref_topk = ref_total.topk(TOP_K).indices
    print(f"Reference top-k (w_h=1): {ref_topk.tolist()}")
    print(f"Reference scores (w_h=1): {ref_total[ref_topk].tolist()}")

    # Overlap between kernel and reference selections. Entries equal to -1
    # are dropped first (presumably unfilled slots; confirm against kernel).
    kernel_set = set(topk_indices.cpu().tolist())
    kernel_set.discard(-1)
    ref_set = set(ref_topk.tolist())
    overlap = len(kernel_set.intersection(ref_set))
    print(f"Overlap: {overlap}/{TOP_K}")

    sys.exit(0)
# Run the debug comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()