[Perf] DeepGEMM fused layout kernel for activations: 4.3% throughput improvement, 10.7% TTFT improvement. (#29546)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2025-12-07 07:31:14 -05:00
Committed by: GitHub
Parent: b0f4866a77
Commit: 541a2ef892
5 changed files with 311 additions and 12 deletions


@@ -206,6 +206,191 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
#undef LAUNCH_KERNEL
}

template <typename T, typename DST_DTYPE>
__global__ void per_token_group_quant_8bit_packed_kernel(
    const T* __restrict__ input, void* __restrict__ output_q,
    unsigned int* __restrict__ output_s_packed, const int group_size,
    const int num_groups, const int groups_per_block, const int groups_per_row,
    const int mn, const int tma_aligned_mn, const float eps,
    const float min_8bit, const float max_8bit) {
  const int threads_per_group = 16;
  const int64_t local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  const int64_t block_group_id = blockIdx.x * groups_per_block;
  const int64_t global_group_id = block_group_id + local_group_id;
  if (global_group_id >= num_groups) {
    return;
  }

  const int64_t block_group_offset = global_group_id * group_size;
  float local_absmax = eps;

  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output =
      static_cast<DST_DTYPE*>(output_q) + block_group_offset;

  // shared memory to cache each group's data to avoid double DRAM reads.
  extern __shared__ __align__(16) char smem_raw[];
  T* smem = reinterpret_cast<T*>(smem_raw);
  T* smem_group = smem + local_group_id * group_size;

  constexpr int vec_size = 16 / sizeof(T);
  using vec_t = vllm::vec_n_t<T, vec_size>;

  // copy global -> shared & compute absmax
  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
    float abs_v = fabsf(static_cast<float>(src));
    local_absmax = fmaxf(local_absmax, abs_v);
    dst = src;
  };

  vllm::vectorize_with_alignment<vec_size>(
      group_input,        // in
      smem_group,         // out (shared)
      group_size,         // elements per group
      lane_id,            // thread id
      threads_per_group,  // stride in group
      scalar_op_cache);   // scalar handler

  local_absmax = GroupReduceMax(local_absmax);

  float y_s = local_absmax / max_8bit;
  y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));

  // pack 4 scales into a uint32
  if (lane_id == 0) {
    // map flat group id to 2D indices (mn_idx, sf_k_idx)
    const int sf_k_idx = static_cast<int>(global_group_id % groups_per_row);
    const int mn_idx = static_cast<int>(global_group_id / groups_per_row);

    if (mn_idx < mn) {
      // each uint32 in output_s_packed stores 4 packed scales
      const int sf_k_pack_idx = sf_k_idx / 4;
      const int pos = sf_k_idx % 4;

      // reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit
      // exponent, and place it into the correct byte of the 32-bit word.
      const unsigned int bits = __float_as_uint(y_s);
      const unsigned int exponent = (bits >> 23u) & 0xffu;
      const unsigned int contrib = exponent << (pos * 8u);

      const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx;
      // atomically OR 8-bit exponent into the packed scales buffer
      atomicOr(output_s_packed + out_idx, contrib);
    }
  }

  __syncthreads();

  // quantize shared -> global 8-bit
  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
    dst = DST_DTYPE(q);
  };

  vllm::vectorize_with_alignment<vec_size>(
      smem_group,         // in (shared)
      group_output,       // out (global quant tensor)
      group_size,         // elements
      lane_id,            // tid
      threads_per_group,  // stride
      scalar_op_quant);   // scalar handler
}
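
The lane-0 branch above is the heart of the fused layout: each group's scale is rounded up to a power of two (UE8M0), its IEEE-754 exponent byte is extracted, and that byte is OR-ed into one of the four byte lanes of an int32 word stored MN-major with a TMA-aligned stride. Below is a minimal host-side sketch of the same index and bit arithmetic, for illustration only; the helper name and the example sizes are mine, not part of this commit.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper mirroring the lane-0 packing logic: round a scale up to
// a power of two (UE8M0), take its IEEE-754 exponent byte, and compute the
// word index / byte position it would occupy in the packed scales buffer.
struct PackedScale {
  uint32_t exponent;  // 8-bit biased exponent (the UE8M0 payload)
  int out_idx;        // index of the uint32 word in output_s_packed
  int byte_pos;       // which byte of that word holds this group's scale
};

PackedScale pack_ue8m0_scale(float absmax, float max_8bit, int64_t group_id,
                             int groups_per_row, int tma_aligned_mn) {
  float y_s = absmax / max_8bit;
  y_s = std::exp2(std::ceil(std::log2(std::fmax(std::fabs(y_s), 1e-10f))));

  uint32_t bits;
  std::memcpy(&bits, &y_s, sizeof(bits));  // host-side stand-in for __float_as_uint

  const int sf_k_idx = static_cast<int>(group_id % groups_per_row);
  const int mn_idx = static_cast<int>(group_id / groups_per_row);
  return PackedScale{(bits >> 23u) & 0xffu,
                     (sf_k_idx / 4) * tma_aligned_mn + mn_idx, sf_k_idx % 4};
}

int main() {
  // Example: K = 1024, group_size = 128 -> 8 groups per row, 2 packed words
  // per row; mn = 6 rows -> tma_aligned_mn = 8. Group 13 is row 1, column 5.
  PackedScale p = pack_ue8m0_scale(/*absmax=*/3136.0f, /*max_8bit=*/448.0f,
                                   /*group_id=*/13, /*groups_per_row=*/8,
                                   /*tma_aligned_mn=*/8);
  // absmax / max_8bit = 7, rounded up to 8 = 2^3 -> exponent byte 127 + 3 = 130.
  std::printf("exponent=%u word=%d byte=%d\n", p.exponent, p.out_idx, p.byte_pos);
  return 0;
}
```

Compiled as plain C++ this prints `exponent=130 word=9 byte=1`: the scale of row 1, scale-column 5 lands in byte 1 of packed word 9, exactly where the kernel's atomicOr would place it.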
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
                                       torch::Tensor& output_q,
                                       torch::Tensor& output_s_packed,
                                       int64_t group_size, double eps,
                                       double min_8bit, double max_8bit) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(output_q.is_contiguous());

  const int64_t k = input.size(-1);
  TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
              ") must be divisible by group_size (", group_size, ").");

  const int64_t mn = input.numel() / k;
  const int64_t groups_per_row = k / group_size;
  const int64_t num_groups = mn * groups_per_row;

  TORCH_CHECK(output_s_packed.dim() == 2,
              "output_s_packed must be 2D, got dim=", output_s_packed.dim(),
              ".");

  const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
  const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;

  TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int,
              "output_s_packed must have dtype int32 for UE8M0-packed scales.");

  // DeepGEMM expects SFA scales in MN-major form with shape
  // [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last
  // dimension.
  TORCH_CHECK(output_s_packed.size(0) == mn &&
                  output_s_packed.size(1) == k_num_packed_sfk,
              "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
              "], but got [", output_s_packed.size(0), ", ",
              output_s_packed.size(1), "].");

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int THREADS_PER_GROUP = 16;
  int groups_per_block = 1;
  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  auto dst_type = output_q.scalar_type();
  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

  // zero-initialize packed scales, since we use atomicOr to accumulate
  // exponents from different groups.
  output_s_packed.zero_();

#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE)                                 \
  do {                                                                     \
    dim3 grid(num_blocks);                                                 \
    dim3 block(num_threads);                                               \
    size_t smem_bytes =                                                    \
        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);    \
    per_token_group_quant_8bit_packed_kernel<T, DST_DTYPE>                 \
        <<<grid, block, smem_bytes, stream>>>(                             \
            static_cast<const T*>(input.data_ptr()), output_q.data_ptr(),  \
            reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()),   \
            static_cast<int>(group_size), static_cast<int>(num_groups),    \
            groups_per_block, static_cast<int>(groups_per_row),            \
            static_cast<int>(mn), static_cast<int>(tma_aligned_mn),        \
            static_cast<float>(eps), static_cast<float>(min_8bit),         \
            static_cast<float>(max_8bit));                                 \
  } while (0)

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
        if (dst_type == at::ScalarType::Float8_e4m3fn) {
          LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
        } else if (dst_type == at::ScalarType::Char) {
          LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
        } else {
          TORCH_CHECK(
              false,
              "per_token_group_quant_8bit_packed only supports FP8/INT8 "
              "outputs.");
        }
      }));

#undef LAUNCH_PACKED_KERNEL
}
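
For orientation, here is a hypothetical caller sketch showing how the shapes and dtypes line up. It assumes this file's existing includes, the function defined above, and the usual FP8 E4M3 range of ±448 with the small eps floor on absmax; the wrapper function `quantize_activations_example` and these argument values are illustrative, not part of this commit.

```cpp
// Hypothetical caller (shapes only). For an activation tensor of shape
// [mn, k] quantized in groups of 128 along k:
//   output_q        : [mn, k]                    FP8 (or INT8) values
//   output_s_packed : [mn, ceil(k / (128 * 4))]  int32, 4 UE8M0 scales per word
void quantize_activations_example(const torch::Tensor& input) {
  const int64_t group_size = 128;
  const int64_t k = input.size(-1);
  const int64_t mn = input.numel() / k;
  const int64_t k_num_packed_sfk = (k / group_size + 3) / 4;

  auto output_q = torch::empty_like(
      input, input.options().dtype(at::ScalarType::Float8_e4m3fn));
  auto output_s_packed = torch::empty(
      {mn, k_num_packed_sfk}, input.options().dtype(at::ScalarType::Int));

  // FP8 E4M3 range; eps is the lower bound the kernel applies to absmax.
  per_token_group_quant_8bit_packed(input, output_q, output_s_packed,
                                    group_size, /*eps=*/1e-10,
                                    /*min_8bit=*/-448.0, /*max_8bit=*/448.0);
}
```

The wrapper zero-initializes output_s_packed itself, so the caller only has to get the [mn, ceil_div(k, 512)] shape and int32 dtype right.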

void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q, torch::Tensor& output_s,
                               int64_t group_size, double eps, double fp8_min,