|
|
|
|
@@ -31,8 +31,12 @@
|
|
|
|
|
|
|
|
|
|
namespace vllm {
|
|
|
|
|
|
|
|
|
|
// NVFP4 quantization kernel for experts (low-latency path).
|
|
|
|
|
// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses
|
|
|
|
|
// SiLU(gate)*up before quantization.
|
|
|
|
|
// Use UE4M3 by default.
|
|
|
|
|
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
|
|
|
|
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
|
|
|
|
|
bool SMALL_NUM_EXPERTS = false>
|
|
|
|
|
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|
|
|
|
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
|
|
|
|
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
|
|
|
|
@@ -50,6 +54,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|
|
|
|
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
|
|
|
|
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
|
|
|
|
|
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
|
|
|
|
|
|
|
|
|
|
// Each global thread processes one element
|
|
|
|
|
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
|
|
|
|
@@ -58,13 +64,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|
|
|
|
int rowIdx = globalIdx / colsPerRow;
|
|
|
|
|
int colIdx = globalIdx % colsPerRow;
|
|
|
|
|
|
|
|
|
|
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
|
|
|
|
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
|
|
|
|
// Get the output tensor offset.
|
|
|
|
|
// Same as inOffset because 8 elements are packed into one uint32_t.
|
|
|
|
|
int64_t outOffset = inOffset;
|
|
|
|
|
auto& out_pos = out[outOffset];
|
|
|
|
|
|
|
|
|
|
// Find index within the experts using different strategies based on expert
|
|
|
|
|
// count
|
|
|
|
|
int rowIdx_in_expert = 0;
|
|
|
|
|
@@ -111,6 +110,23 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Load input and optionally apply fused SiLU+Mul
|
|
|
|
|
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
|
|
|
|
|
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
|
|
|
|
PackedVec quant_input;
|
|
|
|
|
if constexpr (FUSE_SILU_MUL) {
|
|
|
|
|
PackedVec in_vec_up =
|
|
|
|
|
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
|
|
|
|
|
quant_input = compute_silu_mul(in_vec, in_vec_up);
|
|
|
|
|
} else {
|
|
|
|
|
quant_input = in_vec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get the output tensor offset.
|
|
|
|
|
// Same as inOffset because 8 elements are packed into one uint32_t.
|
|
|
|
|
int64_t outOffset = rowIdx * colsPerRow + colIdx;
|
|
|
|
|
auto& out_pos = out[outOffset];
|
|
|
|
|
|
|
|
|
|
// Get the global scaling factor, which will be applied to the SF.
|
|
|
|
|
// Note SFScale is the same as next GEMM's alpha, which is
|
|
|
|
|
// (448.f / (Alpha_A / 6.f)).
|
|
|
|
|
@@ -124,12 +140,16 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|
|
|
|
CVT_FP4_NUM_THREADS_PER_SF>(
|
|
|
|
|
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
|
|
|
|
|
|
|
|
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
|
|
|
|
out_pos =
|
|
|
|
|
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
|
|
|
|
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
|
|
|
|
// NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized
|
|
|
|
|
// version). When FUSE_SILU_MUL=true, expects input with gate||up layout and
|
|
|
|
|
// fuses SiLU(gate)*up before quantization.
|
|
|
|
|
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
|
|
|
|
|
bool SMALL_NUM_EXPERTS = false>
|
|
|
|
|
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|
|
|
|
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
|
|
|
|
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
|
|
|
|
@@ -167,6 +187,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|
|
|
|
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
|
|
|
|
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
|
|
|
|
|
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
|
|
|
|
|
|
|
|
|
|
// Each global thread processes one element
|
|
|
|
|
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
|
|
|
|
@@ -175,11 +197,6 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|
|
|
|
int rowIdx = globalIdx / colsPerRow;
|
|
|
|
|
int colIdx = globalIdx % colsPerRow;
|
|
|
|
|
|
|
|
|
|
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
|
|
|
|
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
|
|
|
|
int64_t outOffset = inOffset;
|
|
|
|
|
auto& out_pos = out[outOffset];
|
|
|
|
|
|
|
|
|
|
// Find expert using binary search for better performance with large m_topk
|
|
|
|
|
int rowIdx_in_expert = 0;
|
|
|
|
|
int expert_idx = 0;
|
|
|
|
|
@@ -204,6 +221,21 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Load input and optionally apply fused SiLU+Mul
|
|
|
|
|
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
|
|
|
|
|
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
|
|
|
|
PackedVec quant_input;
|
|
|
|
|
if constexpr (FUSE_SILU_MUL) {
|
|
|
|
|
PackedVec in_vec_up =
|
|
|
|
|
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
|
|
|
|
|
quant_input = compute_silu_mul(in_vec, in_vec_up);
|
|
|
|
|
} else {
|
|
|
|
|
quant_input = in_vec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t outOffset = rowIdx * colsPerRow + colIdx;
|
|
|
|
|
auto& out_pos = out[outOffset];
|
|
|
|
|
|
|
|
|
|
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
|
|
|
|
|
|
|
|
|
uint32_t* SFout_in_expert =
|
|
|
|
|
@@ -214,11 +246,12 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|
|
|
|
CVT_FP4_NUM_THREADS_PER_SF>(
|
|
|
|
|
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
|
|
|
|
|
|
|
|
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
|
|
|
|
out_pos =
|
|
|
|
|
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
template <typename T, bool FUSE_SILU_MUL = false>
|
|
|
|
|
void quant_impl(void* output, void* output_scale, void* input,
|
|
|
|
|
void* input_global_scale, void* input_offset_by_experts,
|
|
|
|
|
void* output_scale_offset_by_experts, int m_topk, int k,
|
|
|
|
|
@@ -246,7 +279,7 @@ void quant_impl(void* output, void* output_scale, void* input,
|
|
|
|
|
if (blockRepeat > 1) {
|
|
|
|
|
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
|
|
|
|
|
if (n_experts >= 4) {
|
|
|
|
|
cvt_fp16_to_fp4<T, false, false>
|
|
|
|
|
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
|
|
|
|
|
<<<grid, block, shared_mem_size, stream>>>(
|
|
|
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
|
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
|
|
|
@@ -256,34 +289,37 @@ void quant_impl(void* output, void* output_scale, void* input,
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
|
|
|
n_experts);
|
|
|
|
|
} else {
|
|
|
|
|
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
|
|
|
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
|
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
|
|
|
n_experts);
|
|
|
|
|
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
|
|
|
|
|
<<<grid, block, shared_mem_size, stream>>>(
|
|
|
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
|
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
|
|
|
n_experts);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (n_experts >= 16) {
|
|
|
|
|
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
|
|
|
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
|
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
|
|
|
n_experts, /* bool low_latency */ true);
|
|
|
|
|
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
|
|
|
|
|
<<<grid, block, 0, stream>>>(
|
|
|
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
|
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
|
|
|
n_experts, /* bool low_latency */ true);
|
|
|
|
|
} else {
|
|
|
|
|
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
|
|
|
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
|
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
|
|
|
n_experts, /* bool low_latency */ true);
|
|
|
|
|
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
|
|
|
|
|
<<<grid, block, 0, stream>>>(
|
|
|
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
|
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
|
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
|
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
|
|
|
n_experts, /* bool low_latency */ true);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -304,19 +340,19 @@ constexpr auto FLOAT = at::ScalarType::Float;
|
|
|
|
|
constexpr auto INT = at::ScalarType::Int;
|
|
|
|
|
constexpr auto UINT8 = at::ScalarType::Byte;
|
|
|
|
|
|
|
|
|
|
void scaled_fp4_experts_quant_sm1xxa(
|
|
|
|
|
torch::Tensor& output, torch::Tensor& output_scale,
|
|
|
|
|
// Common validation for fp4 experts quantization entry points.
|
|
|
|
|
static void validate_fp4_experts_quant_inputs(
|
|
|
|
|
torch::Tensor const& output, torch::Tensor const& output_scale,
|
|
|
|
|
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
|
|
|
torch::Tensor const& input_offset_by_experts,
|
|
|
|
|
torch::Tensor const& output_scale_offset_by_experts) {
|
|
|
|
|
CHECK_INPUT(output, "output must be a CUDA tensor");
|
|
|
|
|
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
|
|
|
|
|
CHECK_INPUT(input, "input must be a CUDA tensor");
|
|
|
|
|
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
|
|
|
|
|
CHECK_INPUT(input_offset_by_experts,
|
|
|
|
|
"input_offset_by_experts must be a CUDA tensor");
|
|
|
|
|
CHECK_INPUT(output_scale_offset_by_experts,
|
|
|
|
|
"output_scale_offset_by_experts must be a CUDA tensor");
|
|
|
|
|
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
|
|
|
|
int64_t k) {
|
|
|
|
|
CHECK_INPUT(output, "output");
|
|
|
|
|
CHECK_INPUT(output_scale, "output_scale");
|
|
|
|
|
CHECK_INPUT(input, "input");
|
|
|
|
|
CHECK_INPUT(input_global_scale, "input_global_scale");
|
|
|
|
|
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
|
|
|
|
|
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
|
|
|
|
|
|
|
|
|
|
TORCH_CHECK(output.dim() == 2);
|
|
|
|
|
TORCH_CHECK(output_scale.dim() == 2);
|
|
|
|
|
@@ -335,8 +371,6 @@ void scaled_fp4_experts_quant_sm1xxa(
|
|
|
|
|
TORCH_CHECK(output_scale.scalar_type() == INT);
|
|
|
|
|
|
|
|
|
|
const int BLOCK_SIZE = 16;
|
|
|
|
|
auto m_topk = input.size(0);
|
|
|
|
|
auto k = input.size(1);
|
|
|
|
|
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
|
|
|
|
auto n_experts = input_global_scale.size(0);
|
|
|
|
|
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
|
|
|
|
@@ -348,7 +382,21 @@ void scaled_fp4_experts_quant_sm1xxa(
|
|
|
|
|
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
|
|
|
|
// 4 means 4 fp8 values are packed into one int32
|
|
|
|
|
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void scaled_fp4_experts_quant_sm1xxa(
|
|
|
|
|
torch::Tensor& output, torch::Tensor& output_scale,
|
|
|
|
|
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
|
|
|
torch::Tensor const& input_offset_by_experts,
|
|
|
|
|
torch::Tensor const& output_scale_offset_by_experts) {
|
|
|
|
|
auto m_topk = input.size(0);
|
|
|
|
|
auto k = input.size(1);
|
|
|
|
|
|
|
|
|
|
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
|
|
|
|
input_global_scale, input_offset_by_experts,
|
|
|
|
|
output_scale_offset_by_experts, m_topk, k);
|
|
|
|
|
|
|
|
|
|
auto n_experts = input_global_scale.size(0);
|
|
|
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
|
|
|
|
const cudaStream_t stream =
|
|
|
|
|
at::cuda::getCurrentCUDAStream(input.get_device());
|
|
|
|
|
@@ -356,7 +404,38 @@ void scaled_fp4_experts_quant_sm1xxa(
|
|
|
|
|
VLLM_DISPATCH_HALF_TYPES(
|
|
|
|
|
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
|
|
|
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
|
|
|
|
vllm::quant_impl<cuda_type>(
|
|
|
|
|
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
|
|
|
|
|
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
|
|
|
|
|
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
|
|
|
|
|
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
|
|
|
|
|
stream);
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
|
|
|
|
torch::Tensor& output, torch::Tensor& output_scale,
|
|
|
|
|
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
|
|
|
torch::Tensor const& input_offset_by_experts,
|
|
|
|
|
torch::Tensor const& output_scale_offset_by_experts) {
|
|
|
|
|
auto m_topk = input.size(0);
|
|
|
|
|
// Input has gate || up layout, so k = input.size(1) / 2
|
|
|
|
|
auto k_times_2 = input.size(1);
|
|
|
|
|
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
|
|
|
|
auto k = k_times_2 / 2;
|
|
|
|
|
|
|
|
|
|
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
|
|
|
|
input_global_scale, input_offset_by_experts,
|
|
|
|
|
output_scale_offset_by_experts, m_topk, k);
|
|
|
|
|
|
|
|
|
|
auto n_experts = input_global_scale.size(0);
|
|
|
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
|
|
|
|
const cudaStream_t stream =
|
|
|
|
|
at::cuda::getCurrentCUDAStream(input.get_device());
|
|
|
|
|
|
|
|
|
|
VLLM_DISPATCH_HALF_TYPES(
|
|
|
|
|
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
|
|
|
|
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
|
|
|
|
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
|
|
|
|
|
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
|
|
|
|
|
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
|
|
|
|
|
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
|
|
|
|
|
|