[Performance][B200] silu_mul_quant: pack scales in int32 (#28358)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
fdfd5075aa
commit
fe1cd7704d
@@ -279,17 +279,17 @@ __device__ __forceinline__ void token_bounds(int32_t n_tokens,
|
||||
}
|
||||
|
||||
template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
|
||||
int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128,
|
||||
int NUM_STAGES = 3>
|
||||
typename scale_t, int THREADS, typename Idx_t, bool CEIL_UE8M0,
|
||||
int GROUP_SIZE = 128, int NUM_STAGES = 3>
|
||||
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
|
||||
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
|
||||
float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
|
||||
scale_t* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
|
||||
// sizes
|
||||
Idx_t E, Idx_t T, Idx_t H,
|
||||
// strides (in elements)
|
||||
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
|
||||
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
|
||||
Idx_t stride_ys_g, Idx_t stride_counts_e) {
|
||||
Idx_t stride_ys_g, Idx_t stride_ys_p, Idx_t stride_counts_e) {
|
||||
#ifndef USE_ROCM
|
||||
static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
|
||||
|
||||
@@ -466,9 +466,22 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
|
||||
|
||||
__nv_fp8x4_e4m3* y_q_base_ptr =
|
||||
reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
|
||||
auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;
|
||||
|
||||
Idx_t scale_group_offset = 0;
|
||||
if constexpr (std::is_same<scale_t, uint8_t>::value) {
|
||||
// packed int32_t format
|
||||
int pack_id = warp_position_scales / 4;
|
||||
int scale_in_pack = warp_position_scales % 4;
|
||||
scale_group_offset = pack_id * stride_ys_p + scale_in_pack * stride_ys_g;
|
||||
} else {
|
||||
scale_group_offset = warp_position_scales * stride_ys_g;
|
||||
}
|
||||
|
||||
scale_t* const y_scale_base_ptr = _y_s + scale_group_offset;
|
||||
|
||||
for (auto j = tokens_lower; j < tokens_upper; j++) {
|
||||
int current_group_id = warp_position_scales; // Running count of which
|
||||
// group is being processed
|
||||
const Idx_t base_ys = expert_id * stride_ys_e;
|
||||
auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
|
||||
__nv_fp8x4_e4m3* y_q_ptr =
|
||||
@@ -509,7 +522,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
|
||||
|
||||
__nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);
|
||||
|
||||
if constexpr (USE_UE8M0) {
|
||||
if constexpr (CEIL_UE8M0) {
|
||||
y_s = hexp2(hceil(hlog2(y_s)));
|
||||
}
|
||||
|
||||
@@ -527,8 +540,24 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
|
||||
y_q_ptr += WARP_SIZE * stride_yq_h;
|
||||
|
||||
if (!lane_id) {
|
||||
*y_s_ptr = y_s;
|
||||
y_s_ptr += stride_ys_g;
|
||||
// Store scales.
|
||||
if constexpr (std::is_same<scale_t, uint8_t>::value) {
|
||||
// Packed UE8MO format. Remove Mantissa.
|
||||
*y_s_ptr = reinterpret_cast<int16_t&>(y_s) >> 7;
|
||||
|
||||
bool const jump_pack = (current_group_id + 1) % 4 == 0;
|
||||
// Minus 3 because we need to get to the first group in the
|
||||
// next pack.
|
||||
y_s_ptr += jump_pack ? (stride_ys_p - 3) : stride_ys_g;
|
||||
|
||||
} else {
|
||||
// float32 format
|
||||
static_assert(std::is_same<scale_t, float>::value);
|
||||
*y_s_ptr = y_s;
|
||||
y_s_ptr += stride_ys_g;
|
||||
}
|
||||
|
||||
current_group_id += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -573,7 +602,7 @@ void persistent_masked_m_silu_mul_quant(
|
||||
const at::Tensor& tokens_per_expert, // (E)
|
||||
at::Tensor& y_q, // (E, T, H) [OUT]
|
||||
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
|
||||
bool use_ue8m0) {
|
||||
bool cast_scale_ue8m0) {
|
||||
#ifndef USE_ROCM
|
||||
|
||||
// This kernel currently only supports H % 128 == 0 and assumes a
|
||||
@@ -583,9 +612,12 @@ void persistent_masked_m_silu_mul_quant(
|
||||
TORCH_CHECK(input.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
|
||||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
|
||||
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
|
||||
|
||||
bool const is_packed_ue8m0 =
|
||||
(y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
|
||||
TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
|
||||
|
||||
using Idx_t = int64_t;
|
||||
|
||||
Idx_t E = input.size(0);
|
||||
@@ -597,15 +629,18 @@ void persistent_masked_m_silu_mul_quant(
|
||||
Idx_t stride_yq_e = y_q.stride(0);
|
||||
Idx_t stride_yq_t = y_q.stride(1);
|
||||
Idx_t stride_yq_h = y_q.stride(2);
|
||||
Idx_t stride_ys_e = y_s.stride(0);
|
||||
Idx_t stride_ys_t = y_s.stride(1);
|
||||
Idx_t stride_ys_g = y_s.stride(2);
|
||||
|
||||
Idx_t stride_counts_e = tokens_per_expert.stride(0);
|
||||
|
||||
int const NUM_GROUPS = H / GROUP_SIZE;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
|
||||
// TODO: Get this from cuda_arch ?
|
||||
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
|
||||
|
||||
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
|
||||
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
|
||||
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
|
||||
int sms = SILU_V2_BLOCK_COUNT; \
|
||||
static constexpr int max_shared_mem_bytes = \
|
||||
@@ -615,43 +650,86 @@ void persistent_masked_m_silu_mul_quant(
|
||||
VLLM_DISPATCH_FP8_TYPES( \
|
||||
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
|
||||
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
|
||||
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \
|
||||
USE_UE8M0, GROUP_SIZE, STAGES> \
|
||||
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
|
||||
Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
|
||||
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
|
||||
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
|
||||
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
|
||||
(fp8_t*)y_q.data_ptr(), \
|
||||
reinterpret_cast<scale_t*>(y_s.data_ptr()), \
|
||||
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
|
||||
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
|
||||
stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \
|
||||
stride_ys_g, stride_counts_e); \
|
||||
stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
|
||||
STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
|
||||
});
|
||||
|
||||
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
|
||||
#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
|
||||
STRIDE_YS_P, CEIL_UE8M0) \
|
||||
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
|
||||
/* 8 warp config */ \
|
||||
static constexpr int NUM_STAGES = 4; \
|
||||
static constexpr int THREAD_COUNT = 256; \
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
|
||||
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
|
||||
} else { \
|
||||
/* 1 warp config */ \
|
||||
static constexpr int THREAD_COUNT = 32; \
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
|
||||
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
|
||||
}
|
||||
|
||||
int const NUM_GROUPS = H / GROUP_SIZE;
|
||||
if (!use_ue8m0) {
|
||||
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
|
||||
/* 8 warps config */
|
||||
static constexpr int NUM_STAGES = 4;
|
||||
static constexpr int THREAD_COUNT = 256;
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
|
||||
} else {
|
||||
/* 1 warp config */
|
||||
static constexpr int THREAD_COUNT = 32;
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
|
||||
}
|
||||
} else {
|
||||
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
|
||||
/* 8 warps config */
|
||||
static constexpr int NUM_STAGES = 4;
|
||||
static constexpr int THREAD_COUNT = 256;
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
|
||||
} else {
|
||||
/* 1 warp config */
|
||||
static constexpr int THREAD_COUNT = 32;
|
||||
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
|
||||
}
|
||||
Idx_t stride_ys_e = y_s.stride(0);
|
||||
Idx_t stride_ys_t = y_s.stride(1);
|
||||
Idx_t stride_ys_g = y_s.stride(2);
|
||||
Idx_t stride_ys_p = 0;
|
||||
if (!cast_scale_ue8m0) {
|
||||
TORCH_CHECK(!is_packed_ue8m0);
|
||||
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
|
||||
false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!is_packed_ue8m0) {
|
||||
// UE8M0 but not packed
|
||||
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
|
||||
true);
|
||||
return;
|
||||
}
|
||||
|
||||
TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
|
||||
TORCH_CHECK(y_s.dtype() == torch::kInt32);
|
||||
|
||||
// Int32 packed ue8m0 scales tensor.
|
||||
// Let E, T, G be the number to experts, number of tokens and number of groups
|
||||
// respectively. Let, E = 2, T = 4, G = 6, in this case the int32 scales
|
||||
// tensor are of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
|
||||
// to be arranged as follows,
|
||||
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
|
||||
// [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
|
||||
// [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
|
||||
// [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
|
||||
// where, TxGy is the scale ue8m0 scale value of Token x, Group y.
|
||||
//
|
||||
// In memory (in bytes) the scale values are arranged as,
|
||||
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G2, T1G3, T1G4, T2G0, T2G1, T2G3, T2G4,
|
||||
// T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
|
||||
// X, X, T3G4, T3G5, X, X]
|
||||
//
|
||||
// An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
|
||||
// as an uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
|
||||
// english, ignoring the Experts dimension, the original int32 tensor is
|
||||
// simply treated as two packed [4, 4] uint8 tensor (or two [4, 1] int32
|
||||
// tensor). The following strides setting reflects this change. Caveat: This
|
||||
// means that the G dimension is no longer contiguous. i.e. Note that to move
|
||||
// from G3 to G4, we need to jump along the packing dimension. The kernel
|
||||
// handles this case.
|
||||
|
||||
stride_ys_e *= sizeof(int32_t);
|
||||
stride_ys_p = T * sizeof(int32_t); // Packing dimension
|
||||
stride_ys_t = sizeof(int32_t);
|
||||
stride_ys_g = 1;
|
||||
|
||||
LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
|
||||
true);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user