SMEM-P: fix thread_idx tuple access

This commit is contained in:
2026-05-23 19:30:09 +00:00
parent 17b91ab3d3
commit 59b086451c

View File

@@ -375,7 +375,10 @@ class FmhaKernel:
# softmax warps are 128 threads (4 warps × 32)
# Compute which softmax thread this is (0-127)
-softmax_thread_idx = sfw_idx * 32 + (thread_idx % 32)
+# thread_idx is (x,y,z) tuple, take x component
+thread_x = thread_idx[0]
+# sfw_idx is warp index within softmax group (0-3)
+softmax_thread_idx = sfw_idx * 32 + (thread_x % 32)
print(f"[SMEM-P MANUAL] Softmax thread idx: {softmax_thread_idx}")
# Each thread handles 128 P values starting at index * 128