test: volatile SMEM writes + 2 K-tiles
This commit is contained in:
@@ -42,15 +42,18 @@ test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k,
|
||||
__syncthreads();
|
||||
|
||||
// Write Q (1, hd) to sQ row 0 in canonical layout
|
||||
+ // Use volatile to prevent the compiler from optimizing away the writes
|
||||
+ volatile bf16_t* vsQ = (volatile bf16_t*)sQ;
|
||||
+ volatile bf16_t* vsK = (volatile bf16_t*)sK;
|
||||
for (int d = tid; d < hd; d += 128) {
|
||||
int ck = d / 8, lc = d % 8;
|
||||
- sQ[ck * 16 * 64 + lc] = q[d];
|
||||
+ vsQ[ck * 16 * 64 + lc] = q[d];
|
||||
}
|
||||
// Write K (sk, hd) to sK in canonical layout
|
||||
for (int i = tid; i < sk * hd; i += 128) {
|
||||
int r = i / hd, c = i % hd;
|
||||
int tmn = r / 8, ck = c / 8, lr = r % 8, lc = c % 8;
|
||||
- sK[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[i];
|
||||
+ vsK[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user