auto: pre-test commit
This commit is contained in:
59
debug_linear.py
Normal file
59
debug_linear.py
Normal file
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python3
"""
Compute expected O for linear pattern P_ij = i*128 + j.

Use same random seed as test.

The script builds random Q/K/V, then prints the first four output columns of
row 0 for three attention-probability matrices P:
  * the softmax reference P,
  * the "linear" debug pattern P_ij = i*n_kv + j,
  * the constant pattern P = 1.0,
and compares the linear-pattern result against a hard-coded kernel output.
"""
import math

import torch

torch.manual_seed(42)

hd = 256
n_q = 128
n_kv = 128
scale_softmax = 1.0 / math.sqrt(hd)

# Prefer CUDA, but fall back to CPU so the script still runs on a machine
# without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device != 'cuda':
    # CPU and CUDA generators produce different random streams, so the
    # numbers below are only comparable to a GPU run when CUDA is used.
    print("NOTE: CUDA not available, running on CPU; random Q/K/V may not match the GPU test")

# Generate random Q,K,V as in test
q = torch.randn(n_q, hd, dtype=torch.bfloat16, device=device)
k = torch.randn(n_kv, hd, dtype=torch.bfloat16, device=device)
v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device=device)

# Compute reference P via softmax
scores = (q @ k.mT) * scale_softmax
p_ref = torch.softmax(scores, dim=-1)  # n_q x n_kv

# Linear pattern P: P[i, j] = i * n_kv + j, built in one shot instead of a
# per-element Python loop. The values are rounded when cast to bfloat16
# (8 significant bits), so integers above 256 are not all exactly
# representable -- the same rounding the element-wise assignment would apply.
p_linear = (
    torch.arange(n_q * n_kv, dtype=torch.float32, device=device)
    .reshape(n_q, n_kv)
    .to(torch.bfloat16)
)

# Compute O for both patterns; each result is 2-D: (n_q, hd)
o_ref = p_ref @ v
o_linear = p_linear @ v

print("Reference O shape:", o_ref.shape)
print("Linear O shape:", o_linear.shape)
print("\nFirst row, first 4 cols:")
# O is 2-D, so row 0 / first four columns is [0, :4] (no third index).
print("O_ref[0,:4] =", o_ref[0, :4].tolist())
print("O_lin[0,:4] =", o_linear[0, :4].tolist())

# Compute expected scaling if mapping correct
# Kernel output for hd=256: out[0,:4]=[0.029296875, 0.0164794921875, -0.029541015625, 0.02294921875]
kernel_out = [0.029296875, 0.0164794921875, -0.029541015625, 0.02294921875]
print("\nKernel out[0,:4] =", kernel_out)

# Compare with linear pattern
print("\nDifference kernel vs linear:")
for i in range(4):
    linear_val = o_linear[0, i].item()
    diff = kernel_out[i] - linear_val
    print(f"col {i}: kernel {kernel_out[i]:.6f} vs linear {linear_val:.6f} diff={diff:.6f}")

# Compute cosine similarity
kernel_tensor = torch.tensor(kernel_out, dtype=torch.float32)
# kernel_tensor lives on the CPU, so bring the slice of O to the CPU as well;
# cosine_similarity requires both operands on the same device.
linear_tensor = o_linear[0, :4].float().cpu()
cos = torch.cosine_similarity(kernel_tensor, linear_tensor, dim=0).item()
print(f"\nCosine similarity (first 4 cols): {cos:.6f}")

# Also compute expected O for P=1.0 pattern
p_one = torch.ones(n_q, n_kv, dtype=torch.bfloat16, device=device)
o_one = p_one @ v
print("\nP=1.0 pattern O[0,:4] =", o_one[0, :4].tolist())
|
||||
Reference in New Issue
Block a user