auto: pre-test commit
This commit is contained in:
@@ -41,7 +41,7 @@ test_fmha_hd64_n16(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
bf16_t* sK0 = sQ0 + NKT_QK * TILE_SZ;
|
||||
bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + NKT_QK * TILE_SZ) + 127) & ~(uintptr_t)127);
|
||||
bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
|
||||
float* s_p_vals = (float*)(sV + NKT_PV * V_SUB_TILE_SZ); // Only 1 V sub-tile at a time
|
||||
float* s_p_vals = (float*)(sV + V_SUB_TILE_SZ); // Only 1 V sub-tile at a time
|
||||
|
||||
// Load Q K-tiles (4 × (128,16) canonical)
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
@@ -234,6 +234,9 @@ int main() {
|
||||
cudaFuncSetAttribute(test_fmha_hd64_n16, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
|
||||
test_fmha_hd64_n16<<<1, 128, smem>>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE);
|
||||
|
||||
cudaError_t launch_err = cudaGetLastError();
|
||||
if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); return 1; }
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user