From 1e675ccc9a7bbca50759ee1da093123dc9aed4ee Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 08:00:39 +0000 Subject: [PATCH] Fix causal mask shape for SDPA: (1,1,T,T) broadcast --- cutedsl/csa_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cutedsl/csa_attention.py b/cutedsl/csa_attention.py index 53bbaa02..7173a1af 100644 --- a/cutedsl/csa_attention.py +++ b/cutedsl/csa_attention.py @@ -409,8 +409,8 @@ def full_attention_reference( # Q: (T, NH, HD) → (T*NH, 1, HD) q_2d = q.reshape(T * NH, 1, HD) - # Causal mask - causal_mask = torch.tril(torch.ones(T, T, device=q.device, dtype=torch.bool)).unsqueeze(0) + # Causal mask: (1, 1, T, T) broadcast over batch dim T*NH + causal_mask = torch.tril(torch.ones(T, T, device=q.device, dtype=torch.bool)).unsqueeze(0).unsqueeze(0) out = F.scaled_dot_product_attention( q_2d, k_2d, v_2d,