diff --git a/tests/unit/test_b2_debug_scores.py b/tests/unit/test_b2_debug_scores.py new file mode 100644 index 00000000..79a31723 --- /dev/null +++ b/tests/unit/test_b2_debug_scores.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""B2 debug: compare kernel scores with FP32 reference. + +Prints actual score values to find where the kernel diverges. +""" +import sys +import math +import torch +import torch.nn.functional as F + + +def quantize_fp8_e4m3(x_fp32): + amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + scale = amax / 448.0 + fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn) + return fp8.view(torch.uint8), scale.squeeze(-1) + + +def dequantize_fp8_e4m3(fp8_uint8, scale): + fp8 = fp8_uint8.view(torch.float8_e4m3fn) + return fp8.float() * scale.unsqueeze(-1).float() + + +def main(): + torch.manual_seed(42) + N_IH = 64; IHD = 128; N_COMP = 128; TOP_K = 128 + + q_idx = torch.randn(N_IH, IHD, dtype=torch.bfloat16).cuda() * 0.5 + k_fp32 = torch.randn(N_COMP, IHD, dtype=torch.float32) * 0.5 + w_h = torch.randn(N_IH, dtype=torch.bfloat16).cuda() * 0.3 + k_fp8, k_scale = quantize_fp8_e4m3(k_fp32) + k_fp8 = k_fp8.cuda(); k_scale = k_scale.cuda() + + # --- FP32 reference --- + k_dequant = dequantize_fp8_e4m3(k_fp8.view(torch.uint8).cpu(), k_scale.cpu()).cuda() + scores_full = torch.einsum('nd,cd->nc', q_idx.float(), k_dequant.float()) # (64, 128) + scores_relu = F.relu(scores_full) + ref_scores = (scores_relu * w_h.unsqueeze(-1).float()).sum(0) # (128,) + + print(f"Reference scores: [{ref_scores.min():.4f}, {ref_scores.max():.4f}] mean={ref_scores.mean():.4f}") + print(f" Positive scores: {(ref_scores > 0).sum().item()}/{N_COMP}") + print(f" First 8: {ref_scores[:8].tolist()}") + + # Also check per-head scores + print(f"\nPer-head score stats:") + for h in [0, 1, 32, 63]: + head_scores = scores_full[h] # (128,) + print(f" H{h}: [{head_scores.min():.4f}, {head_scores.max():.4f}] " + f"relu_sum={scores_relu[h].sum():.4f} weighted_relu={(scores_relu[h]*w_h[h].float()).sum():.4f}") + + # --- Run B2 kernel --- + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("indexer_fp8_score_topk", ["indexer_fp8_score_topk.cu"], + extra_cuda_cflags=[ + "-gencode=arch=compute_100a,code=sm_100a", + "-O3", "--use_fast_math", "--expt-relaxed-constexpr", + ]) + + topk_indices = torch.empty(TOP_K, dtype=torch.int32, device='cuda') + mod.indexer_fp8_score_topk( + q_idx, k_fp8, k_scale, w_h, topk_indices, + N_IH, IHD, TOP_K) + torch.cuda.synchronize() + + print(f"\nKernel top-k indices: {topk_indices[:16].tolist()}") + print(f" Valid (>=0): {(topk_indices >= 0).sum().item()}") + print(f" Unique: {len(set(topk_indices.cpu().tolist()))}") + + # Compare: kernel-selected scores vs reference + kernel_selected = topk_indices[topk_indices >= 0].cpu() + if len(kernel_selected) > 0: + ref_selected_scores = ref_scores[kernel_selected] + print(f" Kernel-selected ref scores: {ref_selected_scores[:8].tolist()}") + + # Reference top-k + ref_topk = ref_scores.topk(TOP_K).indices + print(f" Reference top-k: {ref_topk[:8].tolist()}") + + # Overlap + kernel_set = set(topk_indices.cpu().tolist()) - {-1} + ref_set = set(ref_topk.tolist()) + overlap = len(kernel_set & ref_set) + print(f" Overlap: {overlap}/{len(ref_set)}") + + # Check: are the FP8 Q scales correct? + # The kernel quantizes Q internally. Let's compute what it should use. + q_fp32 = q_idx.float() + q_amax = q_fp32.abs().amax(dim=-1) # (64,) + q_scales = q_amax / 448.0 + q_scales = q_scales.clamp(min=1e-8) + print(f"\nQ per-row scales: [{q_scales.min():.6f}, {q_scales.max():.6f}] mean={q_scales.mean():.6f}") + + sys.exit(0) + + +if __name__ == "__main__": + main()