feat: HD=512 support — TMEM_N=512, test variants for all three TMA kernels

This commit is contained in:
2026-05-30 03:45:05 +00:00
parent 762f054d6d
commit 8b1ac380ac
6 changed files with 9 additions and 3 deletions

View File

@@ -48,7 +48,7 @@ fmha_6warp_tma_kernel(
static constexpr int N_NSUB = HD / 16;
static constexpr int TILE_SZ = 128 * MMA_K_BF16;
static constexpr int V_SUB_SZ = 16 * MMA_K_BF16;
-static constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
+static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512;
static constexpr int MAX_ROWS = 128;
static constexpr int CORES_MN = 128 / 8;
static constexpr int NUM_READS = SK_TILE / 8;

View File

@@ -39,7 +39,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
static constexpr int N_NSUB = HD / 16;
static constexpr int TILE_SZ = 128 * MMA_K_BF16;
static constexpr int V_SUB_SZ = 16 * MMA_K_BF16;
-static constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
+static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512;
static constexpr int MAX_ROWS = 128;
static constexpr int CORES_MN = 128 / 8;
static constexpr int NUM_READS = SK_TILE / 8;

View File

@@ -61,7 +61,7 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) {
static constexpr int N_NSUB = HD / 16;
static constexpr int TILE_SZ = 128 * MMA_K_BF16;
static constexpr int V_SUB_SZ = 16 * MMA_K_BF16;
-static constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
+static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512;
static constexpr int CORES_MN = 128 / 8;
static constexpr int NUM_READS = SK_TILE / 8;
static constexpr int TMA_TILE_BYTES = TILE_SZ * sizeof(bf16_t);

View File

@@ -0,0 +1,2 @@
+#define HD_VAL 512
+#include "test_fmha_6warp_tma.cu"

View File

@@ -0,0 +1,2 @@
+#define HD_VAL 512
+#include "test_fmha_6warp_tma_multirow.cu"

View File

@@ -0,0 +1,2 @@
+#define HD_VAL 512
+#include "test_fmha_6warp_tma_multitile.cu"