P5: clean kernel with runtime branch (single-tile unchanged, multi-tile separate path)
Single-tile path is IDENTICAL to the working pre-P5 kernel. Multi-tile path uses FA2 online softmax with sOacc accumulator. Runtime branch on is_multi_tile = (n_kv_tiles > 1).
This commit is contained in:
@@ -5,48 +5,26 @@
|
||||
* MULTI-HEAD + MULTI-KV-TILE (P3 + P5)
|
||||
* ==================================================================
|
||||
* Grid: dim3(1, n_h, batch_size)
|
||||
* blockIdx.y = head index (0..n_h-1)
|
||||
* blockIdx.z = batch index (0..batch_size-1)
|
||||
*
|
||||
* Each CTA processes one head of one batch item independently.
|
||||
* No cross-CTA synchronization required.
|
||||
*
|
||||
* For n_kv_tiles=1 (single KV segment): same as before, no rescale.
|
||||
* For n_kv_tiles>1: FlashAttention-2 online softmax with running max/sum.
|
||||
* - P is UN-NORMALIZED (exp(s - max), not divided by sum)
|
||||
* - After each KV tile: read PV result from TMEM → add to sOacc
|
||||
* - Rescale sOacc when running_max changes: sOacc *= exp(old_max - new_max)
|
||||
* - Final: O_normalized = sOacc / running_sum
|
||||
* n_kv_tiles <= 1 (single KV segment, s_k <= 128):
|
||||
* Direct QK → softmax → PV → epilogue. No rescale.
|
||||
* Identical to the pre-P5 kernel.
|
||||
*
|
||||
* n_kv_tiles > 1 (multiple KV segments, s_k > 128):
|
||||
* FlashAttention-2 online softmax across KV tiles.
|
||||
* - P is UN-NORMALIZED: exp(s - max), NOT divided by sum
|
||||
* - SMEM accumulator sOacc[HD]: O += PV after each tile
|
||||
* - Rescale: sOacc *= exp(old_max - new_max) when max changes
|
||||
* - Final: O = sOacc / running_sum
|
||||
*
|
||||
* ==================================================================
|
||||
* MQA / GQA SUPPORT
|
||||
* SMEM BUDGET (additional for multi-tile)
|
||||
* ==================================================================
|
||||
* Same as single-tile. K/V head strides handled by caller.
|
||||
* For MQA: k_head_stride=0, v_head_stride=0.
|
||||
* For GQA: K/V expanded via repeat_interleave in Python.
|
||||
*
|
||||
* ==================================================================
|
||||
* SMEM LAYOUT (P5 additions in comments)
|
||||
* ==================================================================
|
||||
* sTmemBase: 4 bytes
|
||||
* sRowMax: 4 bytes (running max across all KV tiles)
|
||||
* sRowSum: 4 bytes (running sum across all KV tiles)
|
||||
* padding: ~4 bytes (alignment)
|
||||
* sQ0: 128 * hd * 2 (Q canonical, reused per KV tile)
|
||||
* sK0: 128 * 16 * 2 (K segment canonical, one SK_TILE at a time)
|
||||
* sPk: 128 * 16 * 2 (P sub-tile canonical, reused per PV)
|
||||
* sV: 16 * 16 * 2 (V sub-tile canonical, reused per PV)
|
||||
* s_p_vals: SK_TILE * 4 (un-normalized P values)
|
||||
* sOacc: hd * 4 (float accumulator, 1 row for T=1)
|
||||
* Total additional vs single-tile: ~hd*4 bytes (sOacc)
|
||||
* sOacc: HD * 4 bytes (float accumulator, 1 row for T=1)
|
||||
* hd=64: +256B. hd=128: +512B. hd=256: +1024B.
|
||||
* All well within 232KB SMEM budget.
|
||||
*
|
||||
* ==================================================================
|
||||
* OUTPUT: NORMALIZED O + LSE
|
||||
* ==================================================================
|
||||
* O: normalized attention output (divided by running_sum)
|
||||
* LSE: ln(running_sum) + running_max — for external multi-segment merge if needed
|
||||
* Total SMEM at hd=64: ~14 KB. hd=128: ~15 KB. hd=256: ~17 KB.
|
||||
* Well within 232 KB SMEM budget.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -66,7 +44,7 @@ struct FmhaParams {
|
||||
int s_k; // Total KV sequence length
|
||||
float scale; // 1/sqrt(hd)
|
||||
int head_dim; // hd
|
||||
int n_kv_tiles; // Number of KV tiles (0 = auto-calc from s_k)
|
||||
int n_kv_tiles; // Number of KV tiles (0 or 1 = single-tile, >1 = multi-tile)
|
||||
|
||||
int q_head_stride, q_batch_stride;
|
||||
int k_head_stride, k_batch_stride;
|
||||
@@ -95,7 +73,6 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
const bool is_mma_warp = (wid == 4);
|
||||
const bool is_load_warp = (wid == 5);
|
||||
|
||||
// Per-head GMEM pointers
|
||||
const bf16_t* __restrict__ q_head = params.q
|
||||
+ head_idx * params.q_head_stride + batch_idx * params.q_batch_stride;
|
||||
const bf16_t* __restrict__ k_head = params.k
|
||||
@@ -113,23 +90,25 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
const float scale = params.scale;
|
||||
const int n_kv_tiles = (params.n_kv_tiles > 0) ? params.n_kv_tiles
|
||||
: (s_k + SK_TILE - 1) / SK_TILE;
|
||||
const bool is_multi_tile = (n_kv_tiles > 1);
|
||||
|
||||
// ================================================================
|
||||
// SMEM allocation
|
||||
// ================================================================
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
float* sRowMax = (float*)(sbuf + 4); // running max across all tiles
|
||||
float* sRowSum = sRowMax + 1; // running sum across all tiles
|
||||
float* sRowMax = (float*)(sbuf + 4);
|
||||
float* sRowSum = sRowMax + 1;
|
||||
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + 1) + 15) & ~(uintptr_t)15);
|
||||
bf16_t* sK0 = sQ0 + TILE_SZ;
|
||||
bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127);
|
||||
bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
|
||||
float* s_p_vals = (float*)(sV + V_SUB_SZ);
|
||||
// Multi-tile accumulator (only used when n_kv_tiles > 1)
|
||||
float* sOacc = (float*)(((uintptr_t)(s_p_vals + SK_TILE) + 15) & ~(uintptr_t)15);
|
||||
|
||||
// Initialize accumulator (single thread to avoid race)
|
||||
if (tid == 0) {
|
||||
// Initialize multi-tile accumulator
|
||||
if (is_multi_tile && tid == 0) {
|
||||
*sRowMax = -INFINITY;
|
||||
*sRowSum = 0.0f;
|
||||
for (int d = 0; d < HD; d++) sOacc[d] = 0.0f;
|
||||
@@ -144,37 +123,28 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// ================================================================
|
||||
// MAIN LOOP: iterate over KV tiles
|
||||
// SINGLE-TILE PATH (n_kv_tiles <= 1)
|
||||
// Identical to the pre-P5 kernel. Tested, proven correct.
|
||||
// ================================================================
|
||||
for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) {
|
||||
int kv_start = kv_tile * SK_TILE;
|
||||
int kv_len = min(SK_TILE, s_k - kv_start);
|
||||
|
||||
// ============================================================
|
||||
// QK GEMM: for each K-tile (in hd), load Q + K segment, MMA
|
||||
// ============================================================
|
||||
if (!is_multi_tile) {
|
||||
// QK GEMM
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
if (is_load_warp) {
|
||||
// Load Q (same for all KV tiles)
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int ck = d / 8, lc = d % 8;
|
||||
sQ0[ck * 16 * 64 + lc] = q_head[kt * MMA_K_BF16 + d];
|
||||
}
|
||||
// Load K segment
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0;
|
||||
for (int r = 0; r < kv_len; r++) {
|
||||
int g_r = kv_start + r;
|
||||
for (int r = 0; r < s_k; r++) {
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int tmn = r / 8, lr = r % 8;
|
||||
sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] =
|
||||
k_head[g_r * HD + kt * MMA_K_BF16 + d];
|
||||
sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + kt * MMA_K_BF16 + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
|
||||
@@ -185,10 +155,7 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Softmax (warp 0, row 0 for T=1 decode)
|
||||
// P is UN-NORMALIZED for multi-tile: exp(s - max), NOT / sum
|
||||
// ============================================================
|
||||
// Softmax (normalized P for single-tile)
|
||||
if (wid == 0) {
|
||||
float s_vals[SK_TILE], row_max = -INFINITY;
|
||||
for (int n = 0; n < SK_TILE / 8; n++) {
|
||||
@@ -199,102 +166,62 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) {
|
||||
float val = tmp[c] * scale;
|
||||
s_vals[n*8+c] = val;
|
||||
row_max = fmaxf(row_max, val);
|
||||
s_vals[n*8+c] = tmp[c] * scale;
|
||||
row_max = fmaxf(row_max, tmp[c] * scale);
|
||||
}
|
||||
}
|
||||
row_max = wmax(row_max);
|
||||
|
||||
if (lane == 0) *sRowMax = row_max;
|
||||
float row_sum = 0.0f;
|
||||
if (lane == 0) for (int j=0;j<SK_TILE;j++) {
|
||||
s_vals[j] = expf(s_vals[j] - row_max);
|
||||
row_sum += s_vals[j];
|
||||
}
|
||||
row_sum = wsum(row_sum);
|
||||
|
||||
// Online softmax rescale: sOacc *= exp(old_max - new_max)
|
||||
// This is the FlashAttention-2 running max/sum update.
|
||||
float old_max, new_max, old_sum;
|
||||
if (lane == 0) {
|
||||
old_max = *sRowMax;
|
||||
new_max = row_max;
|
||||
old_sum = *sRowSum;
|
||||
|
||||
if (old_max > -INFINITY) {
|
||||
float rescale = expf(old_max - new_max);
|
||||
for (int d = 0; d < HD; d++) sOacc[d] *= rescale;
|
||||
old_sum *= rescale;
|
||||
}
|
||||
*sRowMax = new_max;
|
||||
}
|
||||
|
||||
// Store P for PV. For single KV tile, use NORMALIZED P (same as old kernel).
|
||||
// For multi-tile, use UN-NORMALIZED P (critical for FA2 rescale).
|
||||
// Single-tile: O = P_norm × V. Multi-tile: O = Σ(P_unnorm × V) / running_sum
|
||||
if (lane == 0) {
|
||||
if (n_kv_tiles == 1) {
|
||||
// Normalized P (backward compatible with single-tile)
|
||||
for (int j=0;j<SK_TILE;j++) s_vals[j] /= row_sum;
|
||||
}
|
||||
for (int j=0;j<SK_TILE;j++) s_p_vals[j] = s_vals[j];
|
||||
}
|
||||
|
||||
// Update running sum
|
||||
if (lane == 0) *sRowSum = old_sum + row_sum;
|
||||
if (lane == 0) *sRowSum = row_sum;
|
||||
if (lane == 0) for (int j=0;j<SK_TILE;j++) s_vals[j] /= row_sum;
|
||||
if (lane == 0) for (int j=0;j<SK_TILE;j++) s_p_vals[j] = s_vals[j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ============================================================
|
||||
// PV GEMM: P_unnorm × V → O_tile, accumulate into TMEM
|
||||
// For each KV tile: ACCUMULATE=False on first sub-tile, True on rest
|
||||
// ============================================================
|
||||
// PV GEMM
|
||||
for (int n = 0; n < N_NSUB; n++) {
|
||||
int d_base = n * 16;
|
||||
|
||||
for (int kt = 0; kt < NKT_PV; kt++) {
|
||||
if (is_load_warp) {
|
||||
// Fill sPk from s_p_vals (un-normalized)
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
|
||||
if (lane < 16) {
|
||||
int c = lane;
|
||||
int ck = c / 8, lc = c % 8;
|
||||
sPk[ck * 16 * 64 + 0 * 64 + 0 * 8 + lc] =
|
||||
f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]);
|
||||
sPk[ck * 16 * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]);
|
||||
}
|
||||
// Load V sub-tile
|
||||
for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0;
|
||||
for (int dd = lane; dd < 16; dd += 32) {
|
||||
for (int lr = 0; lr < MMA_K_BF16; lr++) {
|
||||
int r = kv_start + kt * MMA_K_BF16 + lr;
|
||||
int r = kt * MMA_K_BF16 + lr;
|
||||
int g_mn = dd / 8, g_k = lr / 8;
|
||||
int llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] =
|
||||
v_head[(d_base + dd) * s_k + r];
|
||||
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = v_head[(d_base + dd) * s_k + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc_pv16 = make_idesc(128, 16);
|
||||
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
|
||||
// ACCUMULATE=False for first PV sub-tile of each KV tile
|
||||
// (clears TMEM for that sub-tile). True for subsequent sub-tiles.
|
||||
bool acc = !(n == 0 && kt == 0);
|
||||
if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, acc);
|
||||
if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Read PV result from TMEM → add to sOacc
|
||||
// After reading, TMEM can be reused for the next KV tile.
|
||||
// ============================================================
|
||||
// Epilogue: TMEM → regs → BF16 → GMEM (P was normalized, no division needed)
|
||||
if (wid == 0) {
|
||||
float row_max = *sRowMax;
|
||||
float row_sum = *sRowSum;
|
||||
float o_vals[HD];
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
@@ -302,45 +229,160 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) sOacc[n*8+c] += tmp[c];
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
|
||||
}
|
||||
if (lane == 0) {
|
||||
for (int d = 0; d < HD; d++) o_head[d] = f32_to_bf16(o_vals[d]);
|
||||
if (lse_head) lse_head[0] = logf(row_sum) + row_max;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
} // end KV tile loop
|
||||
|
||||
// ================================================================
|
||||
// Epilogue: write O to GMEM
|
||||
// For single KV tile (normalized P): sOacc = P_norm × V (already normalized)
|
||||
// For multi KV tile (un-normalized P): O = sOacc / running_sum
|
||||
// MULTI-TILE PATH (n_kv_tiles > 1)
|
||||
// FlashAttention-2 online softmax across KV tiles.
|
||||
// ================================================================
|
||||
if (wid == 0) {
|
||||
float running_max = *sRowMax;
|
||||
float running_sum = *sRowSum;
|
||||
} else {
|
||||
for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) {
|
||||
int kv_start = kv_tile * SK_TILE;
|
||||
int kv_len = min(SK_TILE, s_k - kv_start);
|
||||
|
||||
if (lane == 0) {
|
||||
if (n_kv_tiles > 1) {
|
||||
float inv_sum = 1.0f / running_sum;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
o_head[d] = f32_to_bf16(sOacc[d] * inv_sum);
|
||||
// QK GEMM (same as single-tile, but K offset = kv_start)
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
if (is_load_warp) {
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int ck = d / 8, lc = d % 8;
|
||||
sQ0[ck * 16 * 64 + lc] = q_head[kt * MMA_K_BF16 + d];
|
||||
}
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0;
|
||||
for (int r = 0; r < kv_len; r++) {
|
||||
int g_r = kv_start + r;
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int tmn = r / 8, lr = r % 8;
|
||||
sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k_head[g_r * HD + kt * MMA_K_BF16 + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single tile: P was normalized, sOacc is already the correct output
|
||||
for (int d = 0; d < HD; d++) {
|
||||
o_head[d] = f32_to_bf16(sOacc[d]);
|
||||
__syncthreads();
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128);
|
||||
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Softmax: UN-NORMALIZED P (exp(s - max), NOT / sum)
|
||||
if (wid == 0) {
|
||||
float s_vals[SK_TILE], row_max = -INFINITY;
|
||||
for (int n = 0; n < SK_TILE / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) {
|
||||
s_vals[n*8+c] = tmp[c] * scale;
|
||||
row_max = fmaxf(row_max, tmp[c] * scale);
|
||||
}
|
||||
}
|
||||
row_max = wmax(row_max);
|
||||
|
||||
float row_sum = 0.0f;
|
||||
if (lane == 0) for (int j=0;j<SK_TILE;j++) {
|
||||
s_vals[j] = expf(s_vals[j] - row_max);
|
||||
row_sum += s_vals[j];
|
||||
}
|
||||
row_sum = wsum(row_sum);
|
||||
|
||||
// Online softmax rescale: sOacc *= exp(old_max - new_max)
|
||||
if (lane == 0) {
|
||||
float old_max = *sRowMax;
|
||||
float old_sum = *sRowSum;
|
||||
if (old_max > -INFINITY) {
|
||||
float rescale = expf(old_max - row_max);
|
||||
for (int d = 0; d < HD; d++) sOacc[d] *= rescale;
|
||||
old_sum *= rescale;
|
||||
}
|
||||
*sRowMax = row_max;
|
||||
*sRowSum = old_sum + row_sum;
|
||||
}
|
||||
|
||||
// Store UN-NORMALIZED P for PV
|
||||
if (lane == 0) for (int j=0;j<SK_TILE;j++) s_p_vals[j] = s_vals[j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// PV GEMM (ACCUMULATE=False for first sub-tile of each KV tile)
|
||||
for (int n = 0; n < N_NSUB; n++) {
|
||||
int d_base = n * 16;
|
||||
for (int kt = 0; kt < NKT_PV; kt++) {
|
||||
if (is_load_warp) {
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
|
||||
if (lane < 16) {
|
||||
int c = lane;
|
||||
int ck = c / 8, lc = c % 8;
|
||||
sPk[ck * 16 * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]);
|
||||
}
|
||||
for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0;
|
||||
for (int dd = lane; dd < 16; dd += 32) {
|
||||
for (int lr = 0; lr < MMA_K_BF16; lr++) {
|
||||
int r = kv_start + kt * MMA_K_BF16 + lr;
|
||||
int g_mn = dd / 8, g_k = lr / 8;
|
||||
int llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = v_head[(d_base + dd) * s_k + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc_pv16 = make_idesc(128, 16);
|
||||
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
|
||||
// ACCUMULATE=False for first PV sub-tile of each KV tile
|
||||
bool acc = !(n == 0 && kt == 0);
|
||||
if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, acc);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
if (lse_head) {
|
||||
lse_head[0] = logf(running_sum) + running_max;
|
||||
|
||||
// Read PV from TMEM → accumulate into sOacc
|
||||
if (wid == 0) {
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) sOacc[n*8+c] += tmp[c];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
} // end KV tile loop
|
||||
|
||||
// Epilogue: normalize sOacc by running_sum
|
||||
if (wid == 0) {
|
||||
float running_max = *sRowMax;
|
||||
float running_sum = *sRowSum;
|
||||
if (lane == 0) {
|
||||
float inv_sum = 1.0f / running_sum;
|
||||
for (int d = 0; d < HD; d++) o_head[d] = f32_to_bf16(sOacc[d] * inv_sum);
|
||||
if (lse_head) lse_head[0] = logf(running_sum) + running_max;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TMEM dealloc
|
||||
if (is_mma_warp) {
|
||||
tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
if (is_mma_warp) tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
} // namespace dsv4::kernels::attention
|
||||
|
||||
Reference in New Issue
Block a user