diff --git a/dsv4/kernels/attention/fmha_tma.cuh b/dsv4/kernels/attention/fmha_tma.cuh index fa8589e3..bc7bacbf 100644 --- a/dsv4/kernels/attention/fmha_tma.cuh +++ b/dsv4/kernels/attention/fmha_tma.cuh @@ -131,8 +131,11 @@ inline bool create_tma_desc_2d_bf16( uint64_t global_str[] = {1, cols}; // Tile dimensions: (tile_cols, tile_rows) — innermost-first uint32_t tile_dim[] = {tile_cols, tile_rows}; - // Tile strides: (1, tile_cols) - uint32_t tile_str[] = {1, tile_cols}; + // Tile strides: (1, cols) — MUST match global strides, not tile dimensions + // The tile elements are accessed with the same stride as the global tensor. + // Moving from row r to r+1 within the tile means jumping 'cols' elements + // in the global tensor, regardless of tile width. + uint32_t tile_str[] = {1, (uint32_t)cols}; CUresult res = cuTensorMapEncodeTiled( out,