Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: elvircrn <elvircrn@gmail.com>
Signed-off-by: Elvir Crnčević <elvircrn@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
This commit is contained in:
Elvir Crnčević
2025-10-10 17:19:53 +02:00
committed by GitHub
parent ae9d0e7da5
commit 7b03584de8
6 changed files with 519 additions and 405 deletions

View File

@@ -138,12 +138,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void silu_mul_fp8_quant_deep_gemm_cuda(
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
bool use_ue8m0);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

View File

@@ -114,13 +114,22 @@ __global__ void act_and_mul_quant_kernel(
}
__device__ __forceinline__ float silu(float x) {
return (__fdividef(x, (1.f + expf(-x))));
return __fdividef(x, (1.f + expf(-x)));
}
__device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y));
}
__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) {
#ifndef USE_ROCM
return make_bfloat162(__float2bfloat16_rn(silu(x.x)),
__float2bfloat16_rn(silu(x.y)));
#else
return __float22bfloat162_rn(make_float2(silu(x.x), silu(x.y)));
#endif
}
#ifndef USE_ROCM
__device__ __forceinline__ float warp_max(float v) {
static constexpr unsigned FULL_MASK = 0xffffffffu;
@@ -223,224 +232,308 @@ constexpr __nv_bfloat16 get_fp8_min() {
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
}
}
#ifndef USE_ROCM
template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
template <typename Idx_t>
__device__ __forceinline__ int warp_expert_search(
int idx, int n, const Idx_t* __restrict__ input, Idx_t val) {
const Idx_t* input_ptr = input + idx;
int base_offset = 0;
for (;;) {
bool move_on = (idx < n && *input_ptr <= val);
unsigned mask = __ballot_sync(0xffffffff, move_on);
if (mask != 0xffffffffu) {
int last_lane = 31 - __clz(mask);
return base_offset + last_lane;
}
input_ptr += 32;
base_offset += 32;
idx += 32;
}
}
template <int num_parallel_tokens>
__device__ __forceinline__ void token_bounds(int32_t n_tokens,
int32_t worker_id,
int32_t& n_tokens_lower,
int32_t& n_tokens_upper) {
if (n_tokens < num_parallel_tokens && worker_id < n_tokens) {
if (worker_id >= num_parallel_tokens) return;
n_tokens_lower = worker_id;
n_tokens_upper = worker_id + 1;
} else {
int32_t chunk_size = n_tokens / num_parallel_tokens;
int32_t residual = n_tokens - chunk_size * num_parallel_tokens;
auto calc_id = [&](int32_t id) {
if (id < residual)
return min(n_tokens, id * (chunk_size + 1));
else
return min(n_tokens, id * chunk_size + residual);
};
n_tokens_lower = calc_id(worker_id);
n_tokens_upper = calc_id(worker_id + 1);
}
}
template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128,
int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
float* __restrict__ _y_s, const int32_t* __restrict__ counts,
float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
// sizes
int H, int G,
Idx_t E, Idx_t T, Idx_t H,
// strides (in elements)
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
Idx_t stride_ys_g, Idx_t stride_counts_e) {
#ifndef USE_ROCM
static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
static constexpr int LOAD_STAGE_SIZE = 2 * GROUP_SIZE / 8;
static constexpr int LOAD_STAGE_MOD = NUM_STAGES * LOAD_STAGE_SIZE;
static constexpr int COMPUTE_STAGE_SIZE = 2 * GROUP_SIZE / 4;
static constexpr int COMPUTE_STAGE_MOD = COMPUTE_STAGE_SIZE * NUM_STAGES;
extern __shared__ __align__(16) __int128_t smem_128[];
int* s_expert_offsets =
reinterpret_cast<int*>(smem_128 + (SMEM_SIZE_BYTES_Y / 16));
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
// We assign EPS with it's 16-bit unsigned counterpart to allow constexpr.
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
int tid = threadIdx.x;
int warp_id = tid >> 5;
int lane_id = tid & 0x1f;
// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
static constexpr int32_t BFLOAT16_PER_GROUP = 8;
int running_sum{};
if (!warp_id) {
for (int i = 0; i < E; i += WARP_SIZE) {
bool valid = (i + threadIdx.x) < E;
int value =
(valid ? tokens_per_expert[i + threadIdx.x * stride_counts_e] : 0) +
(!lane_id ? running_sum : 0);
// We split the shared memory in half, corresponding to gate and up matrices:
// [...gate_i, ...up_i] where 0 <= i < stages.
static constexpr int32_t S_NUM_128 =
2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
__shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
for (int offset = 1; offset < 32; offset *= 2) {
int n = __shfl_up_sync(0xFFFFFFFFu, value, offset);
if (lane_id >= offset) value += n;
}
const int32_t tid = threadIdx.x;
const int32_t warp_id = tid / WARP_SIZE;
const int32_t lane_id = tid % WARP_SIZE;
if (valid) {
s_expert_offsets[i + threadIdx.x + 1] = value;
}
auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
running_sum = __shfl_sync(0xFFFFFFFFu, value, WARP_SIZE - 1);
}
// block handles one (expert e, group g)
int32_t pid = blockIdx.x;
int32_t e = pid / G;
int32_t g = pid % G;
const int32_t n_tokens = counts[e * stride_counts_e];
if (!n_tokens) {
return; // Exit ASAP.
if (!lane_id) {
s_expert_offsets[0] = 0;
}
}
const Idx_t stride_i_t_128 = stride_i_t / 8u;
__syncthreads();
int32_t n_tokens_lower, n_tokens_upper;
int32_t total_tokens = s_expert_offsets[E];
const int warp_position_yq = warp_id * (H / NUM_WARPS);
const int warp_position_scales = warp_id * (H / (GROUP_SIZE * NUM_WARPS));
// A single block will handle tokens_per_block tokens.
// Each block i iterates over tokens of a slice of n_tokens =
// expert_counts[i], with the size of chunk being
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
// Specialize this, but can be likely fused.
if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
return;
}
n_tokens_lower = blockIdx.y;
n_tokens_upper = blockIdx.y + 1;
} else {
auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
auto calc_id = [&](int32_t id) {
if (id < residual) {
return min(n_tokens, id * (chunk_size + 1));
} else {
return min(n_tokens, id * chunk_size + residual);
}
};
n_tokens_lower = calc_id(blockIdx.y);
n_tokens_upper = calc_id(blockIdx.y + 1);
}
if (n_tokens_lower >= n_tokens_upper) {
// Each warp will get space to store its hidden dim for gate and up.
__int128_t* s_hidden_load = smem_128 + warp_id * ((2 * 128 / 8) * NUM_STAGES);
__int128_t* smem_load_ptr = s_hidden_load + lane_id;
const __nv_bfloat16 fp8_inv = __hdiv(__float2bfloat16(1.f), fp8_max);
int32_t compute_pipeline_offset_64 = 0;
int32_t load_stage_offset{};
const __nv_bfloat16 one_bf16 = __float2bfloat16_rn(1.f);
__int64_t* smem_compute_ptr = reinterpret_cast<__int64_t*>(smem_128) +
warp_id * (2 * (GROUP_SIZE / 4) * NUM_STAGES) +
lane_id;
__int64_t* s_gate64_ptr = smem_compute_ptr;
__int64_t* s_up64_ptr = smem_compute_ptr + GROUP_SIZE / 4;
int tokens_lower, tokens_upper;
token_bounds<BLOCK_COUNT>(total_tokens, blockIdx.x, tokens_lower,
tokens_upper);
Idx_t expert_id{}, expert_offset{}, next_expert_offset{};
int token_id = tokens_lower;
int32_t t_load{};
if (token_id < tokens_upper) {
expert_id = warp_expert_search<int>(lane_id, E, s_expert_offsets, token_id);
expert_offset = s_expert_offsets[expert_id];
next_expert_offset = s_expert_offsets[expert_id + 1];
} else {
// This thread block has no work to do.
return;
}
// We do calculations here, using constexpr wherever possible.
const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
const Idx_t base_yq =
e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
stride_i_t_128 * n_tokens_lower;
auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
auto y_s_ptr =
_y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
stride_yq_t * n_tokens_lower + 4 * lane_id;
int32_t t_load = n_tokens_lower, load_stage_id = 0;
auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
int32_t stage_offset{};
int t_load_bound = H / (GROUP_SIZE * NUM_WARPS);
static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
static constexpr int32_t LOAD_STAGE_MOD =
NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
Idx_t base_i = ((expert_id * stride_i_e) / 8) +
(token_id - expert_offset) * stride_i_t / 8;
const Idx_t gate_warp_offset =
warp_id * ((stride_i_h * H) / (8 * NUM_WARPS)) + (lane_id & 0b1111);
const __int128_t* input_128_ptr =
reinterpret_cast<const __int128_t*>(_input) + gate_warp_offset +
((lane_id < 16) ? 0 : ((H * stride_i_h) / 8));
__int128_t* load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
auto token_offset = token_id - expert_offset;
// Two halves of all threads in a block conduct global loads for gate and up,
// repsectively.
auto load_and_advance_y_pred = [&] {
if (t_load < n_tokens_upper) {
auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
if (t_load < t_load_bound) {
// Here we are simply continuing to load data
// from the current token.
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
stage_offset += LOAD_STAGE_SIZE;
stage_offset %= LOAD_STAGE_MOD;
load_stage_offset += LOAD_STAGE_SIZE;
load_stage_offset %= LOAD_STAGE_MOD;
if (tid < HALF_THREAD_COUNT) {
cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
gate_128_ptr += stride_i_t_128;
} else {
cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
up_128_ptr += stride_i_t_128;
}
cp_async4(smem_load_ptr_staged, load_ptr);
load_ptr += GROUP_SIZE / 8;
++t_load;
} else if (token_id + 1 < tokens_upper) {
// We loaded everything from the current token, let's move on
// to the next one, and we checked that we have more tokens to load.
++token_id;
t_load = 0;
if (token_id >= next_expert_offset) {
// We need to find the next expert.
do {
// This is a loop because it's possible
// that some experts are assigned 0 tokens.
// NOTE: We are guaranteed that there's at least
// one more token left so we don't have to check for
// expert_id bounds.
++expert_id;
// This skips 1 memory read.
expert_offset = next_expert_offset;
next_expert_offset = s_expert_offsets[expert_id + 1];
} while (next_expert_offset == expert_offset);
base_i = expert_id * (stride_i_e / 8);
token_offset = 0;
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
} else {
// We remain within the same expert, so just
// move by H/4 __int128_t (2 * H/8).
base_i += stride_yq_t / 4;
token_offset++;
}
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
load_stage_offset += LOAD_STAGE_SIZE;
load_stage_offset %= LOAD_STAGE_MOD;
cp_async4(smem_load_ptr_staged, load_ptr);
load_ptr += GROUP_SIZE / 8;
++t_load;
++load_stage_id;
}
// We fence even if there is nothing to load to simplify pipelining.
cp_async_fence();
};
// We need to warm-up the pipeline.
#pragma unroll
for (int i = 0; i < NUM_STAGES - 1; i++) {
load_and_advance_y_pred();
}
__int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
lane_id;
__int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
__nv_fp8x4_e4m3* y_q_base_ptr =
reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;
static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;
for (auto j = tokens_lower; j < tokens_upper; j++) {
const Idx_t base_ys = expert_id * stride_ys_e;
auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
__nv_fp8x4_e4m3* y_q_ptr =
y_q_base_ptr + (expert_id * stride_yq_e + token_offset * stride_yq_t +
warp_position_yq * stride_yq_h) /
4;
const int COMPUTE_LIMIT = H / (GROUP_SIZE * NUM_WARPS);
int32_t compute_pipeline_offset_64 = 0;
for (int i = 0; i < COMPUTE_LIMIT; i++) {
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();
load_and_advance_y_pred();
for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
__nv_bfloat162 results_bf162[2];
__int64_t* gate64_ptr = s_gate64_ptr + compute_pipeline_offset_64;
__int64_t* up64_ptr = s_up64_ptr + compute_pipeline_offset_64;
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();
// COMPUTE_STAGE_SIZE/MOD must also be constexpr!
compute_pipeline_offset_64 += COMPUTE_STAGE_SIZE;
compute_pipeline_offset_64 %= COMPUTE_STAGE_MOD;
// We double-buffer pipelined loads so that the next load will
// concurrently run with compute without overwrites.
load_and_advance_y_pred();
__int64_t gate64 = *gate64_ptr;
__int64_t up64 = *up64_ptr;
auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;
// STAGE_SIZE must also be constexpr!
compute_pipeline_offset_64 += STAGE_SIZE;
compute_pipeline_offset_64 %= STAGE_MOD;
// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
__int64_t gate64 = *s_gate_compute_64;
__nv_bfloat162* s_gate_compute_32 =
reinterpret_cast<__nv_bfloat162*>(&gate64);
__int64_t up64 = *s_up_compute_64;
__nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);
// Compute
__nv_bfloat162 res[2];
__nv_bfloat162* s_up_comp = reinterpret_cast<__nv_bfloat162*>(&up64);
__nv_bfloat162* s_gate_comp = reinterpret_cast<__nv_bfloat162*>(&gate64);
#pragma unroll
for (int i = 0; i < 2; i++) {
// For silu, we make sure that div is emitted.
float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
results_bf162[i] = __float22bfloat162_rn(gate);
}
for (int32_t k = 0; k < 2; ++k) {
__nv_bfloat162 gate = silu2_v2(__bfloat1622float2(s_gate_comp[k]));
res[k] = __hmul2(gate, s_up_comp[k]);
}
auto _y_max2 = __hmax2(__habs2(res[0]), __habs2(res[1]));
_y_max2.x = __hmax(__hmax(_y_max2.x, _y_max2.y), EPS);
__nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);
if constexpr (USE_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}
__nv_bfloat16 inv_y = __hdiv(one_bf16, y_s);
auto y_s2 = make_bfloat162(inv_y, inv_y);
#pragma unroll
for (int i = 0; i < 2; i++) {
results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
}
for (int32_t k = 0; k < 2; ++k) {
res[k] = clip(__hmul2(res[k], y_s2), __bfloat162bfloat162(fp8_min),
__bfloat162bfloat162(fp8_max));
}
auto _y_max2 =
__hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));
*y_q_ptr = __nv_fp8x4_e4m3(res[0], res[1]);
y_q_ptr += WARP_SIZE * stride_yq_h;
__nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y));
// An entire group is assigned to a single warp, so a simple warp reduce
// is used.
__nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;
if constexpr (USE_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}
auto inv_y = __float2bfloat16_rn(1.f) / y_s;
auto y_s2 = make_bfloat162(inv_y, inv_y);
#pragma unroll
for (int32_t i = 0; i < 2; ++i) {
results_bf162[i] =
clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
__bfloat162bfloat162(fp8_max));
}
auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
*reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
y_q_ptr += stride_yq_t;
if (lane_id == 0) {
*y_s_ptr = y_s;
y_s_ptr += stride_ys_t;
if (!lane_id) {
*y_s_ptr = y_s;
y_s_ptr += stride_ys_g;
}
}
}
}
#endif
}
} // namespace vllm
@@ -475,14 +568,14 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& tokens_per_expert, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0) {
#ifndef USE_ROCM
// This kernel relies heavily on cp.async and fp8 support.
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
TORCH_CHECK(input.dtype() == torch::kBFloat16);
@@ -491,10 +584,6 @@ void silu_mul_fp8_quant_deep_gemm_cuda(
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % 256 == 0);
// Check that num_parallel_tokens is of power of 2 and between 1 and 64.
TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));
using Idx_t = int64_t;
Idx_t E = input.size(0);
@@ -510,81 +599,54 @@ void silu_mul_fp8_quant_deep_gemm_cuda(
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_counts_e = counts.stride(0);
Idx_t stride_counts_e = tokens_per_expert.stride(0);
static constexpr int GROUP_SIZE = 128;
#define KERNEL_FN \
if (use_ue8m0) { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, true> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
} else { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, false> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
}
#define KERNEL_CALL_H \
if (H % (4 * GROUP_SIZE) == 0) { \
static constexpr int NUM_WARPS = 4; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
} else { \
static constexpr int NUM_WARPS = 1; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
}
#define KERNEL_CALL_TOP_LEVEL \
if (num_parallel_tokens == 1) { \
static constexpr int NUM_PARALLEL_TOKENS = 1; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 2) { \
static constexpr int NUM_PARALLEL_TOKENS = 2; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 4) { \
static constexpr int NUM_PARALLEL_TOKENS = 4; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 8) { \
static constexpr int NUM_PARALLEL_TOKENS = 8; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 16) { \
static constexpr int NUM_PARALLEL_TOKENS = 16; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 32) { \
static constexpr int NUM_PARALLEL_TOKENS = 32; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 64) { \
static constexpr int NUM_PARALLEL_TOKENS = 64; \
KERNEL_CALL_H \
}
Idx_t G;
dim3 block, grid;
auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
G = H / Idx_t(group_size * num_warps);
grid = dim3(E * G, _num_parallel_tokens);
block = dim3(num_warps * WARP_SIZE);
};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
"silu_mul_fp8_quant_deep_gemm_kernel",
[&] { KERNEL_CALL_TOP_LEVEL });
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
dim3 grid(sms), block(THREAD_COUNT); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \
USE_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \
stride_ys_g, stride_counts_e); \
});
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
if (!use_ue8m0) {
if (H >= 4096) {
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
} else {
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
if (H >= 4096) {
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}
}
#endif
}

View File

@@ -33,11 +33,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
ops.def(
"silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s, int group_size, "
"bool use_ue8m0, int num_parallel_tokens) -> ()");
ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA,
&silu_mul_fp8_quant_deep_gemm_cuda);
"persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s,"
"bool use_ue8m0) -> ()");
ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA,
&persistent_masked_m_silu_mul_quant);
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);