Add B2 score debug test

This commit is contained in:
2026-06-03 00:43:44 +00:00
parent afb82b9c89
commit 797345dfe9

View File

@@ -0,0 +1,97 @@
#!/usr/bin/env python3
"""B2 debug: compare kernel scores with FP32 reference.
Prints actual score values to find where the kernel diverges.
"""
import sys
import math
import torch
import torch.nn.functional as F
def quantize_fp8_e4m3(x_fp32):
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
scale = amax / 448.0
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
return fp8.view(torch.uint8), scale.squeeze(-1)
def dequantize_fp8_e4m3(fp8_uint8, scale):
fp8 = fp8_uint8.view(torch.float8_e4m3fn)
return fp8.float() * scale.unsqueeze(-1).float()
def main():
    """Run the B2 indexer kernel on random data and print it next to a reference.

    Steps:
      1. Build random bf16 queries / head weights and FP8-quantized keys
         from a fixed seed.
      2. Compute reference scores in PyTorch from the *dequantized* keys:
         ``sum_h w_h[h] * relu(q[h] . k[c])`` for every key ``c``.
      3. Compile and launch ``indexer_fp8_score_topk`` and print the indices
         it selected, the reference scores at those indices, and the
         overlap with the reference top-k.
      4. Print the per-row FP8 scales Q would get under amax/448 scaling.

    Requires a CUDA device and the project's ``dsv4`` package. Always ends
    with ``sys.exit(0)``; this is a print-only debug aid, not a pass/fail test.
    """
    torch.manual_seed(42)
    N_IH = 64
    IHD = 128
    N_COMP = 128
    TOP_K = 128

    # Random tensors are drawn on the CPU (in this order) and then moved, so
    # the seed above fully determines the inputs.
    q_idx = torch.randn(N_IH, IHD, dtype=torch.bfloat16).cuda() * 0.5
    k_fp32 = torch.randn(N_COMP, IHD, dtype=torch.float32) * 0.5
    w_h = torch.randn(N_IH, dtype=torch.bfloat16).cuda() * 0.3
    k_fp8, k_scale = quantize_fp8_e4m3(k_fp32)
    k_fp8 = k_fp8.cuda()
    k_scale = k_scale.cuda()

    # --- FP32 reference ---
    # The reference uses the dequantized keys, so key quantization error is
    # shared with the kernel and does not show up as a difference.
    k_dequant = dequantize_fp8_e4m3(k_fp8.view(torch.uint8).cpu(), k_scale.cpu()).cuda()
    scores_full = torch.einsum('nd,cd->nc', q_idx.float(), k_dequant.float())  # (64, 128)
    scores_relu = F.relu(scores_full)
    ref_scores = (scores_relu * w_h.unsqueeze(-1).float()).sum(0)  # (128,)
    print(f"Reference scores: [{ref_scores.min():.4f}, {ref_scores.max():.4f}] mean={ref_scores.mean():.4f}")
    print(f" Positive scores: {(ref_scores > 0).sum().item()}/{N_COMP}")
    print(f" First 8: {ref_scores[:8].tolist()}")

    # Also check per-head scores
    print("\nPer-head score stats:")
    # First two, middle and last head (0, 1, 32, 63 for N_IH = 64).
    for h in (0, 1, N_IH // 2, N_IH - 1):
        head_scores = scores_full[h]  # (128,)
        print(f" H{h}: [{head_scores.min():.4f}, {head_scores.max():.4f}] "
              f"relu_sum={scores_relu[h].sum():.4f} weighted_relu={(scores_relu[h]*w_h[h].float()).sum():.4f}")

    # --- Run B2 kernel ---
    # Imported here so the reference numbers above are printed before the
    # extension is loaded.
    from dsv4.kernels.cuda.loader import get_cuda_module
    mod = get_cuda_module("indexer_fp8_score_topk", ["indexer_fp8_score_topk.cu"],
                          extra_cuda_cflags=[
                              "-gencode=arch=compute_100a,code=sm_100a",
                              "-O3", "--use_fast_math", "--expt-relaxed-constexpr",
                          ])
    topk_indices = torch.empty(TOP_K, dtype=torch.int32, device='cuda')
    mod.indexer_fp8_score_topk(
        q_idx, k_fp8, k_scale, w_h, topk_indices,
        N_IH, IHD, TOP_K)
    torch.cuda.synchronize()

    # Copy the kernel output to the host once and reuse it below.
    kernel_idx = topk_indices.cpu()
    kernel_idx_list = kernel_idx.tolist()
    print(f"\nKernel top-k indices: {kernel_idx_list[:16]}")
    print(f" Valid (>=0): {(kernel_idx >= 0).sum().item()}")
    print(f" Unique: {len(set(kernel_idx_list))}")

    # A misbehaving kernel may write indices outside [0, N_COMP). Indexing
    # ref_scores with those would abort this script, so count them and only
    # index with the in-range ones.
    n_out_of_range = int((kernel_idx >= N_COMP).sum())
    if n_out_of_range:
        print(f" Out of range (>={N_COMP}): {n_out_of_range}")
    in_range = (kernel_idx >= 0) & (kernel_idx < N_COMP)
    # int64 is the index dtype accepted by every PyTorch version.
    kernel_selected = kernel_idx[in_range].long()

    # Compare: kernel-selected scores vs reference
    if kernel_selected.numel() > 0:
        ref_selected_scores = ref_scores[kernel_selected.to(ref_scores.device)]
        print(f" Kernel-selected ref scores: {ref_selected_scores[:8].tolist()}")

    # Reference top-k
    ref_topk = ref_scores.topk(TOP_K).indices
    print(f" Reference top-k: {ref_topk[:8].tolist()}")

    # Overlap
    # NOTE: with TOP_K == N_COMP the reference top-k contains every index, so
    # this number equals the count of distinct in-range kernel indices; it
    # only becomes a ranking check once TOP_K < N_COMP.
    kernel_set = set(kernel_selected.tolist())
    ref_set = set(ref_topk.tolist())
    overlap = len(kernel_set & ref_set)
    print(f" Overlap: {overlap}/{len(ref_set)}")

    # Check: are the FP8 Q scales correct?
    # The kernel quantizes Q internally. Let's compute what it should use.
    q_fp32 = q_idx.float()
    q_amax = q_fp32.abs().amax(dim=-1)  # (64,)
    q_scales = q_amax / 448.0
    q_scales = q_scales.clamp(min=1e-8)
    print(f"\nQ per-row scales: [{q_scales.min():.6f}, {q_scales.max():.6f}] mean={q_scales.mean():.6f}")
    sys.exit(0)
# Run the debug comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()