From 3412ff1a9ba8836614994c03ee671db138a8752c Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 29 May 2026 04:41:53 +0000 Subject: [PATCH] fix: TMA tile strides must match global strides, not tile dimensions The tile stride in the outer dimension should be the global row stride (cols), not the tile width. The tile is a window into the global tensor and elements are addressed with global strides. --- dsv4/kernels/attention/fmha_tma.cuh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha_tma.cuh b/dsv4/kernels/attention/fmha_tma.cuh index fa8589e3..bc7bacbf 100644 --- a/dsv4/kernels/attention/fmha_tma.cuh +++ b/dsv4/kernels/attention/fmha_tma.cuh @@ -131,8 +131,11 @@ inline bool create_tma_desc_2d_bf16( uint64_t global_str[] = {1, cols}; // Tile dimensions: (tile_cols, tile_rows) — innermost-first uint32_t tile_dim[] = {tile_cols, tile_rows}; - // Tile strides: (1, tile_cols) - uint32_t tile_str[] = {1, tile_cols}; + // Tile strides: (1, cols) — MUST match global strides, not tile dimensions + // The tile elements are accessed with the same stride as the global tensor. + // Moving from row r to r+1 within the tile means jumping 'cols' elements + // in the global tensor, regardless of tile width. + uint32_t tile_str[] = {1, (uint32_t)cols}; CUresult res = cuTensorMapEncodeTiled( out,