fix: TMA tile strides must match global strides, not tile dimensions

The tile stride in the outer dimension should be the global row stride
(cols), not the tile width. The tile is a window into the global tensor
and elements are addressed with global strides.
This commit is contained in:
2026-05-29 04:41:53 +00:00
parent 409838ace2
commit 3412ff1a9b

View File

@@ -131,8 +131,11 @@ inline bool create_tma_desc_2d_bf16(
uint64_t global_str[] = {1, cols};
// Tile dimensions: (tile_cols, tile_rows) — innermost-first
uint32_t tile_dim[] = {tile_cols, tile_rows};
// Tile strides: (1, tile_cols)
uint32_t tile_str[] = {1, tile_cols};
// Tile strides: (1, cols) — MUST match global strides, not tile dimensions
// The tile elements are accessed with the same stride as the global tensor.
// Moving from row r to r+1 within the tile means jumping 'cols' elements
// in the global tensor, regardless of tile width.
uint32_t tile_str[] = {1, (uint32_t)cols};
CUresult res = cuTensorMapEncodeTiled(
out,