[Quant] Make static quant support all group shapes (#30833)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -4,28 +4,77 @@
|
||||
#include "quantization/vectorization_utils.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <tuple>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel_strided(
|
||||
// STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel)
|
||||
// STRIDE_J_ZERO: true if scale_stride_j == 0 (per-tensor or per-token)
|
||||
template <typename scalar_t, typename fp8_type, bool STRIDE_I_ZERO,
|
||||
bool STRIDE_J_ZERO>
|
||||
__global__ void scaled_fp8_quant_kernel_strided_group_shape(
|
||||
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
|
||||
int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x; // one token per block
|
||||
int64_t out_row_stride, int group_m, int group_n, int64_t scale_stride_i,
|
||||
int64_t scale_stride_j) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const scalar_t* token_in = input + token_idx * in_row_stride;
|
||||
fp8_type* token_out = out + token_idx * out_row_stride;
|
||||
|
||||
const float inv_scale = 1.0f / (*scale);
|
||||
// Precompute row-level base offset for scale access (compile-time eliminated
|
||||
// when STRIDE_I_ZERO)
|
||||
const int64_t scale_row_base =
|
||||
STRIDE_I_ZERO ? 0
|
||||
: static_cast<int>(token_idx) / group_m * scale_stride_i;
|
||||
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
inv_scale);
|
||||
});
|
||||
auto get_inv_scale = [&](int gj) {
|
||||
return 1.0f / scale[scale_row_base + gj * scale_stride_j];
|
||||
};
|
||||
|
||||
int cached_gj = -1;
|
||||
float cached_inv_scale = 0.0f;
|
||||
auto get_inv_scale_cached = [&](int gj) {
|
||||
if (gj != cached_gj) {
|
||||
cached_inv_scale = 1.0f / scale[scale_row_base + gj * scale_stride_j];
|
||||
cached_gj = gj;
|
||||
}
|
||||
return cached_inv_scale;
|
||||
};
|
||||
|
||||
constexpr int VEC_SIZE = 16; // FP8 so vectorize to 128 bits
|
||||
auto scaled_fp8_conversion_vectorized = [&](const scalar_t* in, fp8_type* out,
|
||||
int size, float inv_scale) {
|
||||
vectorize_with_alignment<VEC_SIZE>(
|
||||
in, out, size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
inv_scale);
|
||||
});
|
||||
};
|
||||
|
||||
if (STRIDE_J_ZERO && hidden_size % VEC_SIZE == 0) {
|
||||
// Per-tensor or per-token: single scale per row, vectorize full row
|
||||
scaled_fp8_conversion_vectorized(token_in, token_out, hidden_size,
|
||||
get_inv_scale(0));
|
||||
} else if (group_n % VEC_SIZE == 0) {
|
||||
// Multiple column groups with vectorization
|
||||
const int num_groups_n = hidden_size / group_n;
|
||||
|
||||
for (int gj = 0; gj < num_groups_n; gj++) {
|
||||
scaled_fp8_conversion_vectorized(token_in + gj * group_n,
|
||||
token_out + gj * group_n, group_n,
|
||||
get_inv_scale(gj));
|
||||
}
|
||||
} else {
|
||||
// Scalar path: group_n is not a multiple of VEC_SIZE (e.g. group_n < VEC_SIZE)
|
||||
for (int n = tid; n < hidden_size; n += blockDim.x) {
|
||||
const int gj = n / group_n;
|
||||
token_out[n] = scaled_fp8_conversion<true, fp8_type>(
|
||||
static_cast<float>(token_in[n]), get_inv_scale_cached(gj));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
@@ -133,17 +182,116 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale, // various shapes
|
||||
std::optional<std::tuple<int64_t, int64_t>>
|
||||
opt_group_shape) // optional explicit (group_m, group_n)
|
||||
{
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int hidden_size = input.size(-1); // N (columns)
|
||||
const int num_tokens = input.numel() / hidden_size; // M (rows)
|
||||
|
||||
// Determine group_m, group_n, and scale strides from scale shape
|
||||
// Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j]
|
||||
// where gi = m / group_m, gj = n / group_n
|
||||
int group_m, group_n;
|
||||
int64_t scale_stride_i, scale_stride_j;
|
||||
|
||||
if (scale.dim() == 0 || scale.numel() == 1) {
|
||||
// Per-tensor: one scale for the entire tensor
|
||||
group_m = num_tokens;
|
||||
group_n = hidden_size;
|
||||
scale_stride_i = 0;
|
||||
scale_stride_j = 0;
|
||||
} else if (scale.dim() == 1) {
|
||||
// 1D scale: require explicit group_shape to disambiguate per-channel vs
|
||||
// per-token (avoids edge case where num_tokens == hidden_size)
|
||||
TORCH_CHECK(opt_group_shape.has_value(),
|
||||
"1D scale requires explicit group_shape to disambiguate "
|
||||
"per-channel vs per-token quantization. "
|
||||
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
|
||||
"-1) for per-token.");
|
||||
|
||||
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
|
||||
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
|
||||
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
|
||||
|
||||
// Validate the explicit group shape matches the 1D scale
|
||||
const int64_t scale_len = scale.numel();
|
||||
const int64_t expected_scale_m = num_tokens / group_m;
|
||||
const int64_t expected_scale_n = hidden_size / group_n;
|
||||
const int64_t expected_scale_numel = expected_scale_m * expected_scale_n;
|
||||
|
||||
TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (",
|
||||
scale_len, ") does not match expected size (",
|
||||
expected_scale_numel, ") for group_shape (", opt_group_m, ", ",
|
||||
opt_group_n, ") with input shape (", num_tokens, ", ",
|
||||
hidden_size, ")");
|
||||
|
||||
// For 1D scale, determine strides based on which dim is trivial
|
||||
// Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j]
|
||||
// where gi = m / group_m (row group), gj = n / group_n (col group)
|
||||
if (expected_scale_m == 1) {
|
||||
// Per-channel style: one scale in M dim, scale varies along N
|
||||
// gi = 0 always, gj varies, so scale_stride_j (= scale.stride(0)) traverses the scale
|
||||
scale_stride_i = 0;
|
||||
scale_stride_j = scale.stride(0);
|
||||
} else if (expected_scale_n == 1) {
|
||||
// Per-token style: one scale in N dim, scale varies along M
|
||||
// gj = 0 always, gi varies, so scale_stride_i (= scale.stride(0)) traverses the scale
|
||||
scale_stride_i = scale.stride(0);
|
||||
scale_stride_j = 0;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"1D scale can only be used when one of the scale dimensions is 1. "
|
||||
"For 2D group scaling, use a 2D scale tensor.");
|
||||
}
|
||||
} else if (scale.dim() == 2) {
|
||||
// 2D scale: infer group sizes from scale dimensions (or use explicit if
|
||||
// provided)
|
||||
const int64_t scale_size_0 = scale.size(0);
|
||||
const int64_t scale_size_1 = scale.size(1);
|
||||
|
||||
TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens,
|
||||
") must be divisible by scale.size(0) (", scale_size_0, ")");
|
||||
TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", hidden_size,
|
||||
") must be divisible by scale.size(1) (", scale_size_1, ")");
|
||||
|
||||
// Infer from 2D scale shape
|
||||
int inferred_group_m = num_tokens / scale_size_0;
|
||||
int inferred_group_n = hidden_size / scale_size_1;
|
||||
|
||||
// Use explicit if provided, otherwise use inferred
|
||||
if (opt_group_shape.has_value()) {
|
||||
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
|
||||
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
|
||||
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
|
||||
|
||||
// Validate explicit matches inferred
|
||||
TORCH_CHECK(group_m == inferred_group_m && group_n == inferred_group_n,
|
||||
"Explicit group_shape (", opt_group_m, ", ", opt_group_n,
|
||||
") does not match inferred group shape (", inferred_group_m,
|
||||
", ", inferred_group_n, ") from 2D scale tensor shape (",
|
||||
scale_size_0, ", ", scale_size_1, ")");
|
||||
} else {
|
||||
group_m = inferred_group_m;
|
||||
group_n = inferred_group_n;
|
||||
}
|
||||
|
||||
scale_stride_i = scale.stride(0);
|
||||
scale_stride_j = scale.stride(1);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ",
|
||||
scale.dim(), "D");
|
||||
}
|
||||
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(block_size);
|
||||
@@ -153,15 +301,23 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Dispatch to template-specialized kernel based on stride pattern
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
VLLM_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] {
|
||||
VLLM_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] {
|
||||
vllm::scaled_fp8_quant_kernel_strided_group_shape<
|
||||
scalar_t, fp8_t, S0_ZERO, S1_ZERO>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride, group_m, group_n, scale_stride_i,
|
||||
scale_stride_j);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user