fix: TMA tile strides must match global strides, not tile dimensions
The tile stride in the outer dimension should be the global row stride (cols), not the tile width. The tile is a window into the global tensor and elements are addressed with global strides.
This commit is contained in:
@@ -131,8 +131,11 @@ inline bool create_tma_desc_2d_bf16(
|
||||
uint64_t global_str[] = {1, cols};
|
||||
// Tile dimensions: (tile_cols, tile_rows) — innermost-first
|
||||
uint32_t tile_dim[] = {tile_cols, tile_rows};
|
||||
// Tile strides: (1, tile_cols)
|
||||
uint32_t tile_str[] = {1, tile_cols};
|
||||
// Tile strides: (1, cols) — MUST match global strides, not tile dimensions
|
||||
// The tile elements are accessed with the same stride as the global tensor.
|
||||
// Moving from row r to r+1 within the tile means jumping 'cols' elements
|
||||
// in the global tensor, regardless of tile width.
|
||||
uint32_t tile_str[] = {1, (uint32_t)cols};
|
||||
|
||||
CUresult res = cuTensorMapEncodeTiled(
|
||||
out,
|
||||
|
||||
Reference in New Issue
Block a user