diff --git a/debug_linear.py b/debug_linear.py new file mode 100644 index 00000000..d7aaf2ed --- /dev/null +++ b/debug_linear.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +""" +Compute expected O for linear pattern P_ij = i*128 + j. +Use same random seed as test. +""" +import torch +import math + +torch.manual_seed(42) +hd = 256 +n_kv = 128 +scale_softmax = 1.0 / math.sqrt(hd) + +# Generate random Q,K,V as in test +q = torch.randn(128, hd, dtype=torch.bfloat16, device='cuda') +k = torch.randn(128, hd, dtype=torch.bfloat16, device='cuda') +v = torch.randn(128, hd, dtype=torch.bfloat16, device='cuda') + +# Compute reference P via softmax +scores = (q @ k.mT) * scale_softmax +p_ref = torch.softmax(scores, dim=-1) # 128x128 + +# Linear pattern P +p_linear = torch.zeros(128, 128, dtype=torch.bfloat16, device='cuda') +for i in range(128): + for j in range(128): + p_linear[i,j] = i*128 + j + +# Compute O for both patterns +o_ref = p_ref @ v +o_linear = p_linear @ v + +print("Reference O shape:", o_ref.shape) +print("Linear O shape:", o_linear.shape) +print("\nFirst row, first 4 cols:") +print("O_ref[0,:4] =", o_ref[0,:4,0].tolist()) +print("O_lin[0,:4] =", o_linear[0,:4,0].tolist()) + +# Compute expected scaling if mapping correct +# Kernel output for hd=256: out[0,:4]=[0.029296875, 0.0164794921875, -0.029541015625, 0.02294921875] +kernel_out = [0.029296875, 0.0164794921875, -0.029541015625, 0.02294921875] +print("\nKernel out[0,:4] =", kernel_out) + +# Compare with linear pattern +print("\nDifference kernel vs linear:") +for i in range(4): + diff = kernel_out[i] - o_linear[0,i,0].item() + print(f"col {i}: kernel {kernel_out[i]:.6f} vs linear {o_linear[0,i,0].item():.6f} diff={diff:.6f}") + +# Compute cosine similarity +kernel_tensor = torch.tensor(kernel_out, dtype=torch.float32) +linear_tensor = o_linear[0,:4,0].float() +cos = torch.cosine_similarity(kernel_tensor, linear_tensor, dim=0).item() +print(f"\nCosine similarity (first 4 cols): {cos:.6f}") + +# Also compute expected O for P=1.0 pattern +p_one = torch.ones(128,128, dtype=torch.bfloat16, device='cuda') +o_one = p_one @ v +print("\nP=1.0 pattern O[0,:4] =", o_one[0,:4,0].tolist()) \ No newline at end of file