fix: remove duplicate TMEM_COLS_NEEDED declarations
This commit is contained in:
@@ -111,7 +111,7 @@ fmha_decode_tmem(
|
||||
|
||||
// Initialize TMEM O to zero — warp-collective
|
||||
// Use TMEM_COLS_NEEDED columns, each zeroed by all 32 lanes writing 0.
|
||||
constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128;
|
||||
// TMEM_COLS_NEEDED defined above
|
||||
if (wid == 0) {
|
||||
for (int col = 0; col < TMEM_COLS_NEEDED; col++) {
|
||||
tmem_store(tmem_base + col, 0, 0, 0, 0);
|
||||
@@ -233,7 +233,7 @@ fmha_decode_tmem(
|
||||
// Lane i writes sPvBuf[i*4+0..3] to column 0.
|
||||
// Lanes with i*4 >= HD write zeros (the values are don't-care, but every lane must still participate in the write).
|
||||
//
|
||||
constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; // 1 for HD<=128, 2 for HD<=256
|
||||
|
||||
if (wid == 0) {
|
||||
for (int col = 0; col < TMEM_COLS_NEEDED; col++) {
|
||||
int base = col * 128; // FP32 offset for this column
|
||||
@@ -256,7 +256,7 @@ fmha_decode_tmem(
|
||||
|
||||
// Step 2: Read from TMEM to registers (warp 0, warp-collective)
|
||||
// Same lane mapping: lane i reads positions i*4+0..3 from the column.
|
||||
constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128;
|
||||
|
||||
if (wid == 0) {
|
||||
float inv_sum = 1.0f / sRowSums[0];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user