P5: fix TMA desc creation — write to HOST then cudaMemcpy to device

This commit is contained in:
2026-05-30 10:40:01 +00:00
parent f370bfb1f1
commit a2627359fb

View File

@@ -48,15 +48,19 @@ int fmha_multitile_decode_launch(
int idx = b * n_h + h;
// K: (N, hd), TMA tile (128, 16)
if (!create_tma_desc_2d_bf16(d_tma_k + idx, k_head, N, hd, 128, 16)) {
CUtensorMap h_desc;
if (!create_tma_desc_2d_bf16(&h_desc, k_head, N, hd, 128, 16)) {
cudaFree(d_tma_k); cudaFree(d_tma_v);
return -1;
}
cudaMemcpy(d_tma_k + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
// V: (hd, N), TMA tile (16, 16)
if (!create_tma_desc_2d_bf16(d_tma_v + idx, v_head, hd, N, 16, 16)) {
if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N, 16, 16)) {
cudaFree(d_tma_k); cudaFree(d_tma_v);
return -1;
}
cudaMemcpy(d_tma_v + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
}
}