diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 6458df0f..115f0345 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -375,7 +375,10 @@ class FmhaKernel: # softmax warps are 128 threads (4 warps × 32) # Compute which softmax thread this is (0-127) - softmax_thread_idx = sfw_idx * 32 + (thread_idx % 32) + # thread_idx is (x,y,z) tuple, take x component + thread_x = thread_idx[0] + # sfw_idx is warp index within softmax group (0-3) + softmax_thread_idx = sfw_idx * 32 + (thread_x % 32) print(f"[SMEM-P MANUAL] Softmax thread idx: {softmax_thread_idx}") # Each thread handles 128 P values starting at index * 128