diff --git a/dsv4/kernels/attention/fmha_multitile_capi.cu b/dsv4/kernels/attention/fmha_multitile_capi.cu index 85597714..24b57010 100644 --- a/dsv4/kernels/attention/fmha_multitile_capi.cu +++ b/dsv4/kernels/attention/fmha_multitile_capi.cu @@ -48,15 +48,19 @@ int fmha_multitile_decode_launch( int idx = b * n_h + h; // K: (N, hd), TMA tile (128, 16) - if (!create_tma_desc_2d_bf16(d_tma_k + idx, k_head, N, hd, 128, 16)) { + CUtensorMap h_desc; + if (!create_tma_desc_2d_bf16(&h_desc, k_head, N, hd, 128, 16)) { cudaFree(d_tma_k); cudaFree(d_tma_v); return -1; } + cudaMemcpy(d_tma_k + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + // V: (hd, N), TMA tile (16, 16) - if (!create_tma_desc_2d_bf16(d_tma_v + idx, v_head, hd, N, 16, 16)) { + if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N, 16, 16)) { cudaFree(d_tma_k); cudaFree(d_tma_v); return -1; } + cudaMemcpy(d_tma_v + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice); } }