feat: TMA loads for both K and V in 6-warp FMHA kernel
This commit is contained in:
@@ -49,7 +49,7 @@ __global__ void __launch_bounds__(192)
|
||||
fmha_6warp_tma_kernel(
|
||||
const bf16_t* __restrict__ q,
|
||||
CUtensorMap* __restrict__ tma_k,
|
||||
const bf16_t* __restrict__ v,
|
||||
CUtensorMap* __restrict__ tma_v,
|
||||
bf16_t* __restrict__ o,
|
||||
float* __restrict__ lse,
|
||||
int s_k,
|
||||
@@ -206,25 +206,35 @@ fmha_6warp_tma_kernel(
|
||||
int d_base = n * 16;
|
||||
|
||||
for (int kt = 0; kt < NKT_PV; kt++) {
|
||||
// ---- Warp 5: Fill sPk and load V sub-tile ----
|
||||
// ---- Warp 5: Fill sPk ----
|
||||
if (is_load_warp) {
|
||||
// Fill sPk from s_p_vals
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
|
||||
if (lane < 16) {
|
||||
int c = lane;
|
||||
int ck = c / 8, lc = c % 8;
|
||||
sPk[ck * CORES_MN * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]);
|
||||
}
|
||||
// Load V sub-tile: direct from GMEM (same as working fmha_6warp)
|
||||
for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0;
|
||||
for (int dd = lane; dd < 16; dd += 32) {
|
||||
for (int lr = 0; lr < MMA_K_BF16; lr++) {
|
||||
int r = kt * MMA_K_BF16 + lr;
|
||||
int g_mn = dd / 8, g_k = lr / 8;
|
||||
int llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = v[(d_base + dd) * SK_TILE + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ---- TMA load V sub-tile ----
|
||||
// V is (HD, SK) in GMEM. Load (16, 16) tile at (d_base, kt*16)
|
||||
// TMA coord: (kt*16, d_base) = (inner_dim_offset, outer_dim_offset)
|
||||
if (is_load_warp && lane == 0) {
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_v,
|
||||
mbar_addr, kt * MMA_K_BF16, d_base);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, V_SUB_SZ * sizeof(bf16_t));
|
||||
}
|
||||
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
|
||||
__syncthreads();
|
||||
|
||||
// Convert sTmaBuf (row-major 16×16) → sV (canonical 16×16)
|
||||
for (int i = tid; i < V_SUB_SZ; i += 192) sV[i] = 0;
|
||||
for (int i = tid; i < 16 * MMA_K_BF16; i += 192) {
|
||||
int dd = i / MMA_K_BF16, lr = i % MMA_K_BF16; // row, col in (16,16) tile
|
||||
int g_mn = dd / 8, g_k = lr / 8;
|
||||
int llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = sTmaBuf[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@@ -82,11 +82,19 @@ int main() {
|
||||
// TMA descriptor for K: (SK, HD) with tile (128, 16)
|
||||
CUtensorMap tma_k; CUtensorMap* d_tma_k;
|
||||
if (!create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, 128, 16)) {
|
||||
printf("TMA desc FAILED\n"); return 1;
|
||||
printf("TMA K desc FAILED\n"); return 1;
|
||||
}
|
||||
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
|
||||
cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
// TMA descriptor for V: (HD, SK) with tile (16, 16)
|
||||
CUtensorMap tma_v; CUtensorMap* d_tma_v;
|
||||
if (!create_tma_desc_2d_bf16(&tma_v, d_v, HD, SK, 16, 16)) {
|
||||
printf("TMA V desc FAILED\n"); return 1;
|
||||
}
|
||||
cudaMalloc(&d_tma_v, sizeof(CUtensorMap));
|
||||
cudaMemcpy(d_tma_v, &tma_v, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
// Compute reference
|
||||
float o_ref[HD];
|
||||
{
|
||||
@@ -114,7 +122,7 @@ int main() {
|
||||
if (smem > 48 * 1024) {
|
||||
cudaFuncSetAttribute(fmha_6warp_tma_kernel<HD>, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
|
||||
}
|
||||
fmha_6warp_tma_kernel<HD><<<1, 192, smem>>>(d_q, d_tma_k, d_v, d_o, d_lse, SK, SCALE);
|
||||
fmha_6warp_tma_kernel<HD><<<1, 192, smem>>>(d_q, d_tma_k, d_tma_v, d_o, d_lse, SK, SCALE);
|
||||
|
||||
cudaError_t launch_err = cudaGetLastError();
|
||||
if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); return 1; }
|
||||
@@ -137,7 +145,7 @@ int main() {
|
||||
printf("Filtered cosine: %.8f\n", cs);
|
||||
printf("Test %s\n", cs > 0.999f ? "PASSED" : "FAILED");
|
||||
|
||||
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k);
|
||||
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k); cudaFree(d_tma_v);
|
||||
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse);
|
||||
return cs > 0.999f ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user