From 8b1ac380acecd774eddca846144e105d24ffeb00 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 03:45:05 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20HD=3D512=20support=20=E2=80=94=20TMEM?= =?UTF-8?q?=5FN=3D512,=20test=20variants=20for=20all=20three=20TMA=20kerne?= =?UTF-8?q?ls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha_6warp_tma.cuh | 2 +- dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh | 2 +- dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh | 2 +- tests/unit/test_fmha_6warp_tma_hd512.cu | 2 ++ tests/unit/test_fmha_6warp_tma_multirow_hd512.cu | 2 ++ tests/unit/test_fmha_6warp_tma_multitile_hd512.cu | 2 ++ 6 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 tests/unit/test_fmha_6warp_tma_hd512.cu create mode 100644 tests/unit/test_fmha_6warp_tma_multirow_hd512.cu create mode 100644 tests/unit/test_fmha_6warp_tma_multitile_hd512.cu diff --git a/dsv4/kernels/attention/fmha_6warp_tma.cuh b/dsv4/kernels/attention/fmha_6warp_tma.cuh index 5abb50bb..cafd7753 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma.cuh @@ -48,7 +48,7 @@ fmha_6warp_tma_kernel( static constexpr int N_NSUB = HD / 16; static constexpr int TILE_SZ = 128 * MMA_K_BF16; static constexpr int V_SUB_SZ = 16 * MMA_K_BF16; - static constexpr int TMEM_N = (HD <= 128) ? 128 : 256; + static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512; static constexpr int MAX_ROWS = 128; static constexpr int CORES_MN = 128 / 8; static constexpr int NUM_READS = SK_TILE / 8; diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh index 8c5e9387..f54277eb 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh @@ -39,7 +39,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { static constexpr int N_NSUB = HD / 16; static constexpr int TILE_SZ = 128 * MMA_K_BF16; static constexpr int V_SUB_SZ = 16 * MMA_K_BF16; - static constexpr int TMEM_N = (HD <= 128) ? 128 : 256; + static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512; static constexpr int MAX_ROWS = 128; static constexpr int CORES_MN = 128 / 8; static constexpr int NUM_READS = SK_TILE / 8; diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh index 643f18c6..712bb7da 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh @@ -61,7 +61,7 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) { static constexpr int N_NSUB = HD / 16; static constexpr int TILE_SZ = 128 * MMA_K_BF16; static constexpr int V_SUB_SZ = 16 * MMA_K_BF16; - static constexpr int TMEM_N = (HD <= 128) ? 128 : 256; + static constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512; static constexpr int CORES_MN = 128 / 8; static constexpr int NUM_READS = SK_TILE / 8; static constexpr int TMA_TILE_BYTES = TILE_SZ * sizeof(bf16_t); diff --git a/tests/unit/test_fmha_6warp_tma_hd512.cu b/tests/unit/test_fmha_6warp_tma_hd512.cu new file mode 100644 index 00000000..559a275a --- /dev/null +++ b/tests/unit/test_fmha_6warp_tma_hd512.cu @@ -0,0 +1,2 @@ +#define HD_VAL 512 +#include "test_fmha_6warp_tma.cu" diff --git a/tests/unit/test_fmha_6warp_tma_multirow_hd512.cu b/tests/unit/test_fmha_6warp_tma_multirow_hd512.cu new file mode 100644 index 00000000..b86d04da --- /dev/null +++ b/tests/unit/test_fmha_6warp_tma_multirow_hd512.cu @@ -0,0 +1,2 @@ +#define HD_VAL 512 +#include "test_fmha_6warp_tma_multirow.cu" diff --git a/tests/unit/test_fmha_6warp_tma_multitile_hd512.cu b/tests/unit/test_fmha_6warp_tma_multitile_hd512.cu new file mode 100644 index 00000000..c789b951 --- /dev/null +++ b/tests/unit/test_fmha_6warp_tma_multitile_hd512.cu @@ -0,0 +1,2 @@ +#define HD_VAL 512 +#include "test_fmha_6warp_tma_multitile.cu"