Fix causal mask shape for SDPA: use (1, 1, T, T) so the mask broadcasts
This commit is contained in:
@@ -409,8 +409,8 @@ def full_attention_reference(
|
||||
# Q: (T, NH, HD) → (T*NH, 1, HD)
|
||||
q_2d = q.reshape(T * NH, 1, HD)
|
||||
|
||||
- # Causal mask
|
||||
- causal_mask = torch.tril(torch.ones(T, T, device=q.device, dtype=torch.bool)).unsqueeze(0)
|
||||
+ # Causal mask: (1, 1, T, T) broadcast over batch dim T*NH
|
||||
+ causal_mask = torch.tril(torch.ones(T, T, device=q.device, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
q_2d, k_2d, v_2d,
|
||||
|
||||
Reference in New Issue
Block a user