diff --git a/tests/unit/test_b2_debug_qk.py b/tests/unit/test_b2_debug_qk.py new file mode 100644 index 00000000..47944d73 --- /dev/null +++ b/tests/unit/test_b2_debug_qk.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +"""B2 debug: read raw TMEM logits and compare with FP32 QK scores. + +Bypass the weighted ReLU and top-k — just check if the FP8 GEMM +produces correct logits after dequantization. +""" +import sys +import math +import torch +import torch.nn.functional as F + + +def quantize_fp8_e4m3(x_fp32): + amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + scale = amax / 448.0 + fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn) + return fp8.view(torch.uint8), scale.squeeze(-1) + + +def dequantize_fp8_e4m3(fp8_uint8, scale): + fp8 = fp8_uint8.view(torch.float8_e4m3fn) + return fp8.float() * scale.unsqueeze(-1).float() + + +def main(): + torch.manual_seed(42) + N_IH = 64; IHD = 128; N_COMP = 8 # small for manual inspection + TOP_K = 8 + + q_idx = torch.randn(N_IH, IHD, dtype=torch.bfloat16).cuda() * 0.5 + k_fp32 = torch.randn(N_COMP, IHD, dtype=torch.float32) * 0.5 + w_h = torch.ones(N_IH, dtype=torch.bfloat16).cuda() # all positive weights for simplicity + k_fp8, k_scale = quantize_fp8_e4m3(k_fp32) + k_fp8 = k_fp8.cuda(); k_scale = k_scale.cuda() + + # --- FP32 reference scores (full matrix) --- + k_dequant = dequantize_fp8_e4m3(k_fp8.view(torch.uint8).cpu(), k_scale.cpu()).cuda() + scores_ref = torch.einsum('nd,cd->nc', q_idx.float(), k_dequant.float()) # (64, 8) + + print("=== FP32 Reference Scores (h, c) ===") + for h in range(4): + print(f" H{h}: {scores_ref[h, :4].tolist()}") + + # --- Q per-row quantization (same as kernel) --- + q_fp32 = q_idx.float() + q_amax = q_fp32.abs().amax(dim=-1) + q_scales = (q_amax / 448.0).clamp(min=1e-8) + q_fp8_raw = (q_fp32 / q_scales.unsqueeze(-1)).clamp(-448, 448) + q_fp8 = q_fp8_raw.to(torch.float8_e4m3fn).view(torch.uint8) + + # FP8 dequant round-trip + q_dequant = dequantize_fp8_e4m3(q_fp8, q_scales) + cos_q = F.cosine_similarity(q_fp32.flatten().unsqueeze(0), q_dequant.flatten().unsqueeze(0)).item() + print(f"\nQ FP8 round-trip cosine: {cos_q:.6f}") + + # FP8 GEMM in Python: (q_dequant / q_scales) . (k_dequant / k_scales) + # = q_dequant @ k_dequant^T / (q_scales * k_scales) + # But wait — the MMA computes (q_fp8 @ k_fp8^T), and the result is + # raw logits that need to be dequantized: logit * q_scale[h] * k_scale[c] + # q_fp8 = q / q_scale, k_fp8 = k / k_scale + # MMA output = q_fp8 @ k_fp8^T = (q/q_scale) @ (k/k_scale)^T + # Dequantized = MMA_output * q_scale * k_scale = q @ k^T + + # Simulate: compute the raw MMA output as (q_fp8 values @ k_fp8 values) + q8_dequant_for_mma = q_dequant / q_scales.unsqueeze(-1) # "q_fp8" values (but in FP32) + k8_dequant_for_mma = k_dequant / k_scale.unsqueeze(-1) # "k_fp8" values (but in FP32) + mma_output = torch.einsum('nd,cd->nc', q8_dequant_for_mma, k8_dequant_for_mma) # (64, 8) + + # Dequantized logits = mma_output * q_scale * k_scale + dequant_logits = mma_output * q_scales.unsqueeze(-1) * k_scale.unsqueeze(0) + + print(f"\nSimulated dequant logits:") + for h in range(4): + print(f" H{h}: {dequant_logits[h, :4].tolist()}") + print(f" Ref: {scores_ref[h, :4].tolist()}") + + cos_logits = F.cosine_similarity(dequant_logits.flatten().unsqueeze(0), + scores_ref.flatten().unsqueeze(0)).item() + print(f" Cosine vs ref: {cos_logits:.6f}") + + # Now check: the kernel's MMA should produce similar raw logits. + # If we run the kernel and it gives different results, the MMA itself is wrong. + + # --- Run B2 kernel --- + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("indexer_fp8_score_topk", ["indexer_fp8_score_topk.cu"], + extra_cuda_cflags=[ + "-gencode=arch=compute_100a,code=sm_100a", + "-O3", "--use_fast_math", "--expt-relaxed-constexpr", + ]) + + topk_indices = torch.empty(TOP_K, dtype=torch.int32, device='cuda') + mod.indexer_fp8_score_topk( + q_idx, k_fp8, k_scale, w_h, topk_indices, + N_IH, IHD, TOP_K) + torch.cuda.synchronize() + + print(f"\nKernel top-k (w_h=1): {topk_indices.tolist()}") + + # Reference top-k with w_h=1 + scores_relu = F.relu(scores_ref) + ref_total = scores_relu.sum(0) # sum over heads + ref_topk = ref_total.topk(TOP_K).indices + print(f"Reference top-k (w_h=1): {ref_topk.tolist()}") + print(f"Reference scores (w_h=1): {ref_total[ref_topk].tolist()}") + + # Check overlap + kernel_set = set(topk_indices.cpu().tolist()) - {-1} + ref_set = set(ref_topk.tolist()) + overlap = len(kernel_set & ref_set) + print(f"Overlap: {overlap}/{TOP_K}") + + sys.exit(0) + + +if __name__ == "__main__": + main()