SMEM-P: fix thread_idx tuple access
This commit is contained in:
@@ -375,7 +375,10 @@ class FmhaKernel:
|
||||
# softmax warps are 128 threads (4 warps × 32)
|
||||
|
||||
# Compute which softmax thread this is (0-127)
|
||||
softmax_thread_idx = sfw_idx * 32 + (thread_idx % 32)
|
||||
# thread_idx is (x,y,z) tuple, take x component
|
||||
thread_x = thread_idx[0]
|
||||
# sfw_idx is warp index within softmax group (0-3)
|
||||
softmax_thread_idx = sfw_idx * 32 + (thread_x % 32)
|
||||
print(f"[SMEM-P MANUAL] Softmax thread idx: {softmax_thread_idx}")
|
||||
|
||||
# Each thread handles 128 P values starting at index * 128
|
||||
|
||||
Reference in New Issue
Block a user