P5: fix TMA desc creation — write to HOST then cudaMemcpy to device
This commit is contained in:
@@ -48,15 +48,19 @@ int fmha_multitile_decode_launch(
|
||||
int idx = b * n_h + h;
|
||||
|
||||
// K: (N, hd), TMA tile (128, 16)
|
||||
if (!create_tma_desc_2d_bf16(d_tma_k + idx, k_head, N, hd, 128, 16)) {
|
||||
CUtensorMap h_desc;
|
||||
if (!create_tma_desc_2d_bf16(&h_desc, k_head, N, hd, 128, 16)) {
|
||||
cudaFree(d_tma_k); cudaFree(d_tma_v);
|
||||
return -1;
|
||||
}
|
||||
cudaMemcpy(d_tma_k + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
// V: (hd, N), TMA tile (16, 16)
|
||||
if (!create_tma_desc_2d_bf16(d_tma_v + idx, v_head, hd, N, 16, 16)) {
|
||||
if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N, 16, 16)) {
|
||||
cudaFree(d_tma_k); cudaFree(d_tma_v);
|
||||
return -1;
|
||||
}
|
||||
cudaMemcpy(d_tma_v + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user