feat: HD=512 support — TMEM_N=512, test variants for all three TMA kernels
This commit is contained in:
@@ -48,7 +48,7 @@ fmha_6warp_tma_kernel(
|
||||
static constexpr int N_NSUB = HD / 16;
|
||||
static constexpr int TILE_SZ = 128 * MMA_K_BF16;
|
||||
static constexpr int V_SUB_SZ = 16 * MMA_K_BF16;
|
||||
static constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
|
||||
static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512;
|
||||
static constexpr int MAX_ROWS = 128;
|
||||
static constexpr int CORES_MN = 128 / 8;
|
||||
static constexpr int NUM_READS = SK_TILE / 8;
|
||||
|
||||
@@ -39,7 +39,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
static constexpr int N_NSUB = HD / 16;
|
||||
static constexpr int TILE_SZ = 128 * MMA_K_BF16;
|
||||
static constexpr int V_SUB_SZ = 16 * MMA_K_BF16;
|
||||
static constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
|
||||
static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512;
|
||||
static constexpr int MAX_ROWS = 128;
|
||||
static constexpr int CORES_MN = 128 / 8;
|
||||
static constexpr int NUM_READS = SK_TILE / 8;
|
||||
|
||||
@@ -61,7 +61,7 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) {
|
||||
static constexpr int N_NSUB = HD / 16;
|
||||
static constexpr int TILE_SZ = 128 * MMA_K_BF16;
|
||||
static constexpr int V_SUB_SZ = 16 * MMA_K_BF16;
|
||||
static constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
|
||||
static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512;
|
||||
static constexpr int CORES_MN = 128 / 8;
|
||||
static constexpr int NUM_READS = SK_TILE / 8;
|
||||
static constexpr int TMA_TILE_BYTES = TILE_SZ * sizeof(bf16_t);
|
||||
|
||||
2
tests/unit/test_fmha_6warp_tma_hd512.cu
Normal file
2
tests/unit/test_fmha_6warp_tma_hd512.cu
Normal file
@@ -0,0 +1,2 @@
|
||||
#define HD_VAL 512
|
||||
#include "test_fmha_6warp_tma.cu"
|
||||
2
tests/unit/test_fmha_6warp_tma_multirow_hd512.cu
Normal file
2
tests/unit/test_fmha_6warp_tma_multirow_hd512.cu
Normal file
@@ -0,0 +1,2 @@
|
||||
#define HD_VAL 512
|
||||
#include "test_fmha_6warp_tma_multirow.cu"
|
||||
2
tests/unit/test_fmha_6warp_tma_multitile_hd512.cu
Normal file
2
tests/unit/test_fmha_6warp_tma_multitile_hd512.cu
Normal file
@@ -0,0 +1,2 @@
|
||||
#define HD_VAL 512
|
||||
#include "test_fmha_6warp_tma_multitile.cu"
|
||||
Reference in New Issue
Block a user