// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "../../dispatch_utils.h"
#include "quant_conversions.cuh"
#include "../w8a8/fp8/common.cuh"

namespace vllm {

// Logic: one thread block per (token, group) pair
template <typename scalar_t, typename scalar_out_t, bool is_scale_transposed,
          int group_size>
__global__ void silu_and_mul_per_block_quant_kernel(
    scalar_out_t* __restrict__ out,      // Output: [num_tokens, hidden_size]
                                         // in FP8/INT8
    float* __restrict__ scales,          // Output: [num_tokens,
                                         // hidden_size / group_size] or
                                         // [hidden_size / group_size,
                                         // num_tokens]
    scalar_t const* __restrict__ input,  // Input: [num_tokens, hidden_size * 2]
    float const* scale_ub,               // Optional scale upper bound
    int32_t const hidden_size  // Output hidden size (input is 2x this)
) {
  static_assert((group_size & (group_size - 1)) == 0,
                "group_size must be a power of 2 for correct reduction");

  // Grid: (num_tokens, num_groups)
  int const token_idx = blockIdx.x;
  int const group_idx = blockIdx.y;
  int const tid = threadIdx.x;  // tid in [0, group_size)
  int const num_tokens = gridDim.x;

  // Input layout: [gate || up] concatenated along the last dimension
  int const input_stride = hidden_size * 2;
  int const group_start = group_idx * group_size;

  // Pointers to this token's data
  scalar_t const* token_input_gate =
      input + token_idx * input_stride + group_start;
  scalar_t const* token_input_up = token_input_gate + hidden_size;
  scalar_out_t* token_output = out + token_idx * hidden_size + group_start;

  // Scale pointer for this group
  int const num_groups = gridDim.y;
  float* group_scale_ptr =
      is_scale_transposed ? scales + group_idx * num_tokens + token_idx
                          : scales + token_idx * num_groups + group_idx;

  // Shared memory for reduction (compile-time sized)
  __shared__ float shared_max[group_size];

  // Step 1: Each thread loads one gate/up pair, computes SiLU(gate) * up,
  // and keeps the result in a register
  float gate = static_cast<float>(token_input_gate[tid]);
  float up = static_cast<float>(token_input_up[tid]);

  float sigmoid_gate = 1.0f / (1.0f + expf(-gate));
  float silu_gate = gate * sigmoid_gate;
  float result = silu_gate * up;  // Keep in register

  // Step 2: Reduce to find the group absolute max
  shared_max[tid] = fabsf(result);
  __syncthreads();

  // Power-of-2 tree reduction (group_size is guaranteed to be a power of 2)
#pragma unroll
  for (int stride = group_size / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
    }
    __syncthreads();
  }

  // Step 3: Compute scale (thread 0), broadcast via shared memory
  if (tid == 0) {
    float group_max = shared_max[0];
    float const quant_range = quant_type_max_v<scalar_out_t>;
    float group_scale = group_max / quant_range;

    // Apply scale upper bound if provided
    if (scale_ub != nullptr) {
      group_scale = fminf(group_scale, *scale_ub);
    }

    // Use minimum safe scaling factor
    group_scale = fmaxf(group_scale, min_scaling_factor<scalar_out_t>::val());

    // Store scale to global memory
    *group_scale_ptr = group_scale;

    // Reuse shared_max[0] to broadcast scale
    shared_max[0] = group_scale;
  }
  __syncthreads();

  float group_scale = shared_max[0];

  // Step 4: Quantize and write output. The scale is max / range, so it is
  // divided out here (is_scale_inverted = false).
  token_output[tid] =
      vllm::ScaledQuant<scalar_out_t, false>::quant_fn(result, group_scale);
}

}  // namespace vllm
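
// Worked example of the per-group scale arithmetic above (illustrative
// numbers, not from the source): for an FP8 E4M3 output,
// quant_type_max_v<scalar_out_t> is 448. If a group's max |SiLU(gate) * up|
// is 3.2, then group_scale = 3.2 / 448 ~= 0.0071, each element is quantized
// as round(result / group_scale), and the group max itself lands on +/-448,
// the edge of the representable range.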

void silu_and_mul_per_block_quant(torch::Tensor& out,
                                  torch::Tensor const& input,
                                  torch::Tensor& scales, int64_t group_size,
                                  std::optional<at::Tensor> scale_ub,
                                  bool is_scale_transposed) {
  static c10::ScalarType kFp8Type = is_fp8_ocp()
                                        ? c10::ScalarType::Float8_e4m3fn
                                        : c10::ScalarType::Float8_e4m3fnuz;
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
  TORCH_CHECK(
      input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
      "Input must be FP16 or BF16");
  TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
  TORCH_CHECK(group_size == 128 || group_size == 64,
              "Unsupported group size: ", group_size);
  if (scale_ub.has_value()) {
    TORCH_CHECK(out.dtype() == kFp8Type);
  }

  int32_t hidden_size = out.size(-1);
  auto num_tokens = input.size(0);
  int32_t num_groups = hidden_size / group_size;
  TORCH_CHECK(input.size(-1) == hidden_size * 2,
              "input last dim must be 2x output hidden_size");
  TORCH_CHECK(hidden_size % group_size == 0,
              "hidden_size must be divisible by group_size");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(num_tokens, num_groups);
  dim3 block(group_size);

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
        using scalar_in_t = scalar_t;
        VLLM_DISPATCH_QUANT_TYPES(
            out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
              using scalar_out_t = scalar_t;
              VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
                VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
                  vllm::silu_and_mul_per_block_quant_kernel<
                      scalar_in_t, scalar_out_t, transpose_scale, gs>
                      <<<grid, block, 0, stream>>>(
                          out.data_ptr<scalar_out_t>(),
                          scales.data_ptr<float>(),
                          input.data_ptr<scalar_in_t>(),
                          scale_ub.has_value() ? scale_ub->data_ptr<float>()
                                               : nullptr,
                          hidden_size);
                });
              });
            });
      });
}
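
// A minimal host-side usage sketch (illustrative only, assuming this file is
// compiled into an extension that exposes silu_and_mul_per_block_quant; the
// actual binding lives elsewhere in vLLM). Shapes follow the checks above:
//
//   int64_t num_tokens = 4, hidden_size = 256, group_size = 128;
//   auto opts = torch::TensorOptions().device(torch::kCUDA);
//   auto input = torch::randn({num_tokens, hidden_size * 2},
//                             opts.dtype(torch::kFloat16));
//   auto out = torch::empty({num_tokens, hidden_size},
//                           opts.dtype(c10::ScalarType::Float8_e4m3fn));
//   auto scales = torch::empty({num_tokens, hidden_size / group_size},
//                              opts.dtype(torch::kFloat32));
//   silu_and_mul_per_block_quant(out, input, scales, group_size,
//                                std::nullopt, /*is_scale_transposed=*/false);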