[Bugfix] Fix persistent_masked_m_silu_mul_quant tests (#28366)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-11-10 12:21:52 -05:00
committed by GitHub
parent d0e186c16f
commit b039bfda8f
3 changed files with 16 additions and 7 deletions

View File

@@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
static constexpr int GROUP_SIZE = 128;
TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % 256 == 0);
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
using Idx_t = int64_t;
@@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant(
Idx_t stride_counts_e = tokens_per_expert.stride(0);
static constexpr int GROUP_SIZE = 128;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
@@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant(
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
int const NUM_GROUPS = H / GROUP_SIZE;
if (!use_ue8m0) {
if (H >= 4096) {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
if (H >= 4096) {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}