118 lines
4.6 KiB
Python
118 lines
4.6 KiB
Python
#!/usr/bin/env python3
|
|
"""B2 debug: read raw TMEM logits and compare with FP32 QK scores.
|
|
|
|
Bypass the weighted ReLU and top-k — just check if the FP8 GEMM
|
|
produces correct logits after dequantization.
|
|
"""
|
|
import sys
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def quantize_fp8_e4m3(x_fp32):
|
|
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
|
scale = amax / 448.0
|
|
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
|
|
return fp8.view(torch.uint8), scale.squeeze(-1)
|
|
|
|
|
|
def dequantize_fp8_e4m3(fp8_uint8, scale):
|
|
fp8 = fp8_uint8.view(torch.float8_e4m3fn)
|
|
return fp8.float() * scale.unsqueeze(-1).float()
|
|
|
|
|
|
def main():
    """Check simulated FP8 indexer logits against FP32 scores, then run the B2 kernel.

    Steps:
      1. Build a small random problem:
         - bf16 Q of shape (N_IH, IHD) on the GPU,
         - FP32 K of shape (N_COMP, IHD), quantized per row to FP8 E4M3,
         - all-ones head weights ``w_h``.
      2. Compute FP32 reference scores ``q @ dequant(k)^T``.
      3. Quantize Q per row (same as the kernel) and simulate the FP8 MMA in
         FP32 on the exact FP8 code values. Dequantize with
         ``q_scale[h] * k_scale[c]`` and report the cosine similarity against
         the reference.
      4. Run the compiled ``indexer_fp8_score_topk`` kernel. Compare its top-k
         indices with the reference top-k of
         ``sum_h w_h[h] * relu(scores[h, c])``.

    Requires a CUDA device and the project-local ``dsv4`` kernel loader. The
    kernel is built for sm_100a.

    Exits with status 0 when the kernel's top-k index set equals the
    reference set, and 1 otherwise.
    """
    torch.manual_seed(42)
    N_IH, IHD, N_COMP = 64, 128, 8  # small for manual inspection
    TOP_K = 8

    # RNG call order (Q first, then K) is kept so results stay reproducible.
    q_idx = torch.randn(N_IH, IHD, dtype=torch.bfloat16).cuda() * 0.5
    k_fp32 = torch.randn(N_COMP, IHD, dtype=torch.float32) * 0.5
    w_h = torch.ones(N_IH, dtype=torch.bfloat16).cuda()  # all positive weights for simplicity
    k_fp8, k_scale = quantize_fp8_e4m3(k_fp32)
    k_fp8 = k_fp8.cuda()
    k_scale = k_scale.cuda()

    # --- FP32 reference scores (full matrix) ---
    # K is taken after its FP8 round-trip, so the reference differs from the
    # kernel path only through Q quantization and the GEMM/dequant arithmetic.
    k_dequant = dequantize_fp8_e4m3(k_fp8, k_scale)
    scores_ref = torch.einsum('nd,cd->nc', q_idx.float(), k_dequant)  # (N_IH, N_COMP)

    print("=== FP32 Reference Scores (h, c) ===")
    for h in range(4):
        print(f" H{h}: {scores_ref[h, :4].tolist()}")

    # --- Q per-row quantization (same as kernel) ---
    q_fp32 = q_idx.float()
    q_amax = q_fp32.abs().amax(dim=-1)
    q_scales = (q_amax / 448.0).clamp(min=1e-8)
    q_fp8_raw = (q_fp32 / q_scales.unsqueeze(-1)).clamp(-448, 448)
    q_fp8 = q_fp8_raw.to(torch.float8_e4m3fn).view(torch.uint8)

    # FP8 dequant round-trip
    q_dequant = dequantize_fp8_e4m3(q_fp8, q_scales)
    cos_q = F.cosine_similarity(q_fp32.flatten().unsqueeze(0), q_dequant.flatten().unsqueeze(0)).item()
    print(f"\nQ FP8 round-trip cosine: {cos_q:.6f}")

    # The MMA consumes the raw FP8 codes: q_fp8 = q / q_scale, k_fp8 = k / k_scale.
    #   MMA output        = q_fp8 @ k_fp8^T
    #   dequantized logit = MMA_out[h, c] * q_scale[h] * k_scale[c]  (~= q @ k^T)
    # Decode the FP8 codes directly; every E4M3 value is exact in FP32. Dividing
    # the dequantized tensors by their scales would reintroduce FP32 rounding.
    q8_for_mma = q_fp8.view(torch.float8_e4m3fn).float()
    k8_for_mma = k_fp8.view(torch.float8_e4m3fn).float()
    mma_output = torch.einsum('nd,cd->nc', q8_for_mma, k8_for_mma)  # (N_IH, N_COMP)

    # Dequantized logits = mma_output * q_scale[h] * k_scale[c]
    dequant_logits = mma_output * q_scales.unsqueeze(-1) * k_scale.unsqueeze(0)

    print("\nSimulated dequant logits:")
    for h in range(4):
        print(f" H{h}: {dequant_logits[h, :4].tolist()}")
        print(f" Ref: {scores_ref[h, :4].tolist()}")

    cos_logits = F.cosine_similarity(dequant_logits.flatten().unsqueeze(0),
                                     scores_ref.flatten().unsqueeze(0)).item()
    print(f" Cosine vs ref: {cos_logits:.6f}")

    # Now check: the kernel's MMA should produce similar raw logits.
    # If the kernel gives different results, the MMA itself is wrong.

    # --- Run B2 kernel ---
    from dsv4.kernels.cuda.loader import get_cuda_module
    mod = get_cuda_module("indexer_fp8_score_topk", ["indexer_fp8_score_topk.cu"],
                          extra_cuda_cflags=[
                              "-gencode=arch=compute_100a,code=sm_100a",
                              "-O3", "--use_fast_math", "--expt-relaxed-constexpr",
                          ])

    topk_indices = torch.empty(TOP_K, dtype=torch.int32, device='cuda')
    mod.indexer_fp8_score_topk(
        q_idx, k_fp8, k_scale, w_h, topk_indices,
        N_IH, IHD, TOP_K)
    torch.cuda.synchronize()

    print(f"\nKernel top-k (w_h=1): {topk_indices.tolist()}")

    # Reference top-k of sum_h w_h[h] * relu(scores[h, c]).
    # Weighting explicitly keeps the reference correct if w_h is ever changed.
    scores_relu = F.relu(scores_ref)
    ref_total = (scores_relu * w_h.float().unsqueeze(-1)).sum(0)  # reduce over heads
    ref_topk = ref_total.topk(TOP_K).indices
    print(f"Reference top-k (w_h=1): {ref_topk.tolist()}")
    print(f"Reference scores (w_h=1): {ref_total[ref_topk].tolist()}")

    # Check overlap; -1 presumably marks an unfilled kernel slot (TODO confirm).
    kernel_set = set(topk_indices.cpu().tolist()) - {-1}
    ref_set = set(ref_topk.tolist())
    overlap = len(kernel_set & ref_set)
    print(f"Overlap: {overlap}/{TOP_K}")
    if TOP_K >= N_COMP:
        # Every candidate is selected, so set overlap cannot reveal bad logits.
        print(f"NOTE: TOP_K ({TOP_K}) >= N_COMP ({N_COMP}); the overlap check is "
              "vacuous here - rely on the logit cosine above.")

    sys.exit(0 if kernel_set == ref_set else 1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|