#!/usr/bin/env python3
"""
Compute expected O for linear pattern P_ij = i*128 + j.

Use same random seed as test.

The script draws random bfloat16 Q, K, V, then prints O = P @ V for three
choices of P (softmax(Q K^T / sqrt(hd)), the linear index pattern, and
all-ones) next to a hard-coded kernel output so they can be compared by eye.
"""
import math

import torch

SEED = 42
HEAD_DIM = 256  # hd
N_Q = 128       # rows of Q / rows of P
N_KV = 128      # rows of K and V / columns of P
DEVICE = "cuda"
N_SHOW = 4      # how many leading columns of row 0 are printed and compared

# Kernel output for hd=256: out[0,:4]
KERNEL_OUT = [0.029296875, 0.0164794921875, -0.029541015625, 0.02294921875]


def linear_pattern(n_rows, n_cols, *, dtype=torch.bfloat16, device=DEVICE):
    """Return the (n_rows, n_cols) matrix P with P[i, j] = i * n_cols + j.

    The indices are generated in float32 (exact for every value below 2**24)
    and then rounded once to ``dtype``.

    NOTE: bfloat16 keeps only 8 significant bits, so with the default dtype
    entries above 256 are rounded to the nearest representable value; this is
    the same rounding an element-by-element assignment into a bfloat16 tensor
    performs.
    """
    flat = torch.arange(n_rows * n_cols, dtype=torch.float32, device=device)
    return flat.reshape(n_rows, n_cols).to(dtype)


def softmax_probs(q, k):
    """Return softmax(q @ k^T / sqrt(head_dim)) taken over the last dim.

    ``head_dim`` is read from the last dimension of ``q``.
    """
    scale_softmax = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.mT) * scale_softmax
    return torch.softmax(scores, dim=-1)


def cosine_to_reference(values, reference):
    """Return the cosine similarity between a 1-D tensor and a list of floats.

    ``values`` may live on any device; it is converted to float32 on the CPU
    so it can be compared with the CPU tensor built from ``reference``.
    """
    ref_tensor = torch.tensor(reference, dtype=torch.float32)
    val_tensor = values.float().cpu()
    return torch.cosine_similarity(ref_tensor, val_tensor, dim=0).item()


def main(device=DEVICE):
    """Print the reference, linear-pattern and all-ones outputs for row 0."""
    torch.manual_seed(SEED)

    # Generate random Q,K,V as in test. The draw order (q, k, v) is kept so
    # the values produced for a given seed do not change.
    q = torch.randn(N_Q, HEAD_DIM, dtype=torch.bfloat16, device=device)
    k = torch.randn(N_KV, HEAD_DIM, dtype=torch.bfloat16, device=device)
    v = torch.randn(N_KV, HEAD_DIM, dtype=torch.bfloat16, device=device)

    # Compute reference P via softmax
    p_ref = softmax_probs(q, k)  # N_Q x N_KV

    # Linear pattern P
    p_linear = linear_pattern(N_Q, N_KV, dtype=torch.bfloat16, device=device)

    # Compute O for both patterns
    o_ref = p_ref @ v
    o_linear = p_linear @ v

    print("Reference O shape:", o_ref.shape)
    print("Linear O shape:", o_linear.shape)
    print("\nFirst row, first 4 cols:")
    # O is 2-D (N_Q x HEAD_DIM), so row 0 / first columns is a 2-index slice.
    lin_head = o_linear[0, :N_SHOW]
    print("O_ref[0,:4] =", o_ref[0, :N_SHOW].tolist())
    print("O_lin[0,:4] =", lin_head.tolist())

    # Compute expected scaling if mapping correct
    print("\nKernel out[0,:4] =", KERNEL_OUT)

    # Compare with linear pattern
    print("\nDifference kernel vs linear:")
    for col, (kernel_val, lin_val) in enumerate(zip(KERNEL_OUT, lin_head.tolist())):
        diff = kernel_val - lin_val
        print(f"col {col}: kernel {kernel_val:.6f} vs linear {lin_val:.6f} diff={diff:.6f}")

    # Compute cosine similarity
    cos = cosine_to_reference(lin_head, KERNEL_OUT)
    print(f"\nCosine similarity (first 4 cols): {cos:.6f}")

    # Also compute expected O for P=1.0 pattern
    p_one = torch.ones(N_Q, N_KV, dtype=torch.bfloat16, device=device)
    o_one = p_one @ v
    print("\nP=1.0 pattern O[0,:4] =", o_one[0, :N_SHOW].tolist())


if __name__ == "__main__":
    main()