[Bugfix] Fix persistent_masked_m_silu_mul_quant tests (#28366)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
d0e186c16f
commit
b039bfda8f
@@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(
|
||||
|
||||
// This kernel currently only supports H % 128 == 0 and assumes a
|
||||
// fixed GROUP_SIZE of 128.
|
||||
static constexpr int GROUP_SIZE = 128;
|
||||
|
||||
TORCH_CHECK(input.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
|
||||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
|
||||
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(input.size(-1) % 256 == 0);
|
||||
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
|
||||
|
||||
using Idx_t = int64_t;
|
||||
|
||||
@@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant(
|
||||
|
||||
Idx_t stride_counts_e = tokens_per_expert.stride(0);
|
||||
|
||||
static constexpr int GROUP_SIZE = 128;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
|
||||
@@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant(
|
||||
|
||||
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
|
||||
|
||||
int const NUM_GROUPS = H / GROUP_SIZE;
|
||||
if (!use_ue8m0) {
|
||||
if (H >= 4096) {
|
||||
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
|
||||
/* 8 warps config */
|
||||
static constexpr int NUM_STAGES = 4;
|
||||
static constexpr int THREAD_COUNT = 256;
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
|
||||
} else {
|
||||
/* 1 warp config */
|
||||
static constexpr int THREAD_COUNT = 32;
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
|
||||
}
|
||||
} else {
|
||||
if (H >= 4096) {
|
||||
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
|
||||
/* 8 warps config */
|
||||
static constexpr int NUM_STAGES = 4;
|
||||
static constexpr int THREAD_COUNT = 256;
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
|
||||
} else {
|
||||
/* 1 warp config */
|
||||
static constexpr int THREAD_COUNT = 32;
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user