diff --git a/cutedsl/csa_attention.py b/cutedsl/csa_attention.py index 53bbaa02..7173a1af 100644 --- a/cutedsl/csa_attention.py +++ b/cutedsl/csa_attention.py @@ -409,8 +409,8 @@ def full_attention_reference( # Q: (T, NH, HD) → (T*NH, 1, HD) q_2d = q.reshape(T * NH, 1, HD) - # Causal mask - causal_mask = torch.tril(torch.ones(T, T, device=q.device, dtype=torch.bool)).unsqueeze(0) + # Causal mask: (1, 1, T, T) broadcast over batch dim T*NH + causal_mask = torch.tril(torch.ones(T, T, device=q.device, dtype=torch.bool)).unsqueeze(0).unsqueeze(0) out = F.scaled_dot_product_attention( q_2d, k_2d, v_2d,