#!/usr/bin/env python3
# nvfp4-megamoe-kernel/debug_linear.py
# (Header recovered from a repository web view: 59 lines, 1.9 KiB, Python,
#  timestamp 2026-05-23 20:13:47 +00:00.)
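"""
Compute the expected attention output O = P @ V for the linear pattern
P_ij = i*128 + j (alongside the softmax reference and an all-ones P) and
compare it with a captured kernel output.

Q, K and V are drawn with the same random seed as the test.
"""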
import torch
import math
# Problem sizes; these must match the test that produced kernel_out.
hd = 256  # head dimension
n_q = 128  # query rows
n_kv = 128  # key/value rows; also the row stride of the linear pattern
scale_softmax = 1.0 / math.sqrt(hd)

# First four values of row 0 of the kernel output (out[0,:4]), captured from
# a run with hd=256.
kernel_out = [0.029296875, 0.0164794921875, -0.029541015625, 0.02294921875]


def linear_pattern(n_rows, n_cols, *, dtype=torch.bfloat16, device="cuda"):
    """Return an ``(n_rows, n_cols)`` tensor with ``P[i, j] = i * n_cols + j``.

    A single vectorized ``arange`` replaces ``n_rows * n_cols`` scalar
    writes, each of which would be a separate tiny device op on CUDA.
    Casting the integer ramp to ``dtype`` rounds the same way as assigning
    the Python ints element-wise.

    NOTE: bfloat16 represents integers exactly only up to 256. For the
    default 128x128 pattern most entries are therefore rounded to nearby
    representable values.
    """
    ramp = torch.arange(n_rows * n_cols, device=device)
    return ramp.view(n_rows, n_cols).to(dtype)


def main(device="cuda"):
    """Print O = P @ V for several P patterns and compare with ``kernel_out``.

    Computes O for three choices of P:
      * the softmax reference P;
      * the linear pattern P_ij = i*n_kv + j (presumably the pattern the
        kernel under debug was patched to write into P -- confirm);
      * an all-ones P.
    It then prints the per-column difference and the cosine similarity
    between the linear-pattern result and the captured kernel values.

    Args:
        device: Device to compute on. Defaults to "cuda", as in the original
            debug run. PyTorch's CPU and CUDA generators produce different
            random streams, so Q/K/V only reproduce the test's inputs when
            run on the same device type as the test.
    """
    torch.manual_seed(42)
    # Draw Q, K, V in the same order as the test so that V matches its inputs.
    q = torch.randn(n_q, hd, dtype=torch.bfloat16, device=device)
    k = torch.randn(n_kv, hd, dtype=torch.bfloat16, device=device)
    v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device=device)

    # Reference P: scaled dot-product scores followed by a row-wise softmax.
    scores = (q @ k.mT) * scale_softmax
    p_ref = torch.softmax(scores, dim=-1)  # (n_q, n_kv)
    p_linear = linear_pattern(n_q, n_kv, dtype=torch.bfloat16, device=device)

    # O is 2-D, (n_q, hd), so rows and columns are indexed with two indices.
    o_ref = p_ref @ v
    o_linear = p_linear @ v
    print("Reference O shape:", o_ref.shape)
    print("Linear O shape:", o_linear.shape)
    print("\nFirst row, first 4 cols:")
    print("O_ref[0,:4] =", o_ref[0, :4].tolist())
    print("O_lin[0,:4] =", o_linear[0, :4].tolist())

    print("\nKernel out[0,:4] =", kernel_out)
    print("\nDifference kernel vs linear:")
    # Copy the compared slice to the host once. This avoids one device sync
    # per column and keeps it on the same device as kernel_tensor below.
    lin_row = o_linear[0, :len(kernel_out)].float().cpu()
    for col, (kern, lin) in enumerate(zip(kernel_out, lin_row.tolist())):
        diff = kern - lin
        print(f"col {col}: kernel {kern:.6f} vs linear {lin:.6f} diff={diff:.6f}")

    kernel_tensor = torch.tensor(kernel_out, dtype=torch.float32)  # on CPU
    cos = torch.cosine_similarity(kernel_tensor, lin_row, dim=0).item()
    print(f"\nCosine similarity (first 4 cols): {cos:.6f}")

    # All-ones P: each output row is the column sum of V.
    p_one = torch.ones(n_q, n_kv, dtype=torch.bfloat16, device=device)
    o_one = p_one @ v
    print("\nP=1.0 pattern O[0,:4] =", o_one[0, :4].tolist())


if __name__ == "__main__":
    main()