[Quant] Make static quant support all group shapes (#30833)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <optional>
|
||||
#include <torch/library.h>
|
||||
#include <tuple>
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
@@ -346,8 +347,9 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale);
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
|
||||
std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
@@ -4,28 +4,77 @@
|
||||
#include "quantization/vectorization_utils.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <tuple>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel_strided(
|
||||
// STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel)
|
||||
// STRIDE_J_ZERO: true if scale_stride_j == 0 (per-tensor or per-token)
|
||||
template <typename scalar_t, typename fp8_type, bool STRIDE_I_ZERO,
|
||||
bool STRIDE_J_ZERO>
|
||||
__global__ void scaled_fp8_quant_kernel_strided_group_shape(
|
||||
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
|
||||
int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x; // one token per block
|
||||
int64_t out_row_stride, int group_m, int group_n, int64_t scale_stride_i,
|
||||
int64_t scale_stride_j) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const scalar_t* token_in = input + token_idx * in_row_stride;
|
||||
fp8_type* token_out = out + token_idx * out_row_stride;
|
||||
|
||||
const float inv_scale = 1.0f / (*scale);
|
||||
// Precompute row-level base offset for scale access (compile-time eliminated
|
||||
// when STRIDE_I_ZERO)
|
||||
const int64_t scale_row_base =
|
||||
STRIDE_I_ZERO ? 0
|
||||
: static_cast<int>(token_idx) / group_m * scale_stride_i;
|
||||
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
inv_scale);
|
||||
});
|
||||
auto get_inv_scale = [&](int gj) {
|
||||
return 1.0f / scale[scale_row_base + gj * scale_stride_j];
|
||||
};
|
||||
|
||||
int cached_gj = -1;
|
||||
float cached_inv_scale = 0.0f;
|
||||
auto get_inv_scale_cached = [&](int gj) {
|
||||
if (gj != cached_gj) {
|
||||
cached_inv_scale = 1.0f / scale[scale_row_base + gj * scale_stride_j];
|
||||
cached_gj = gj;
|
||||
}
|
||||
return cached_inv_scale;
|
||||
};
|
||||
|
||||
constexpr int VEC_SIZE = 16; // FP8 so vectorize to 128 bits
|
||||
auto scaled_fp8_conversion_vectorized = [&](const scalar_t* in, fp8_type* out,
|
||||
int size, float inv_scale) {
|
||||
vectorize_with_alignment<VEC_SIZE>(
|
||||
in, out, size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
inv_scale);
|
||||
});
|
||||
};
|
||||
|
||||
if (STRIDE_J_ZERO && hidden_size % VEC_SIZE == 0) {
|
||||
// Per-tensor or per-token: single scale per row, vectorize full row
|
||||
scaled_fp8_conversion_vectorized(token_in, token_out, hidden_size,
|
||||
get_inv_scale(0));
|
||||
} else if (group_n % VEC_SIZE == 0) {
|
||||
// Multiple column groups with vectorization
|
||||
const int num_groups_n = hidden_size / group_n;
|
||||
|
||||
for (int gj = 0; gj < num_groups_n; gj++) {
|
||||
scaled_fp8_conversion_vectorized(token_in + gj * group_n,
|
||||
token_out + gj * group_n, group_n,
|
||||
get_inv_scale(gj));
|
||||
}
|
||||
} else {
|
||||
// Scalar path for small column groups (group_n < VEC_SIZE)
|
||||
for (int n = tid; n < hidden_size; n += blockDim.x) {
|
||||
const int gj = n / group_n;
|
||||
token_out[n] = scaled_fp8_conversion<true, fp8_type>(
|
||||
static_cast<float>(token_in[n]), get_inv_scale_cached(gj));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
@@ -133,17 +182,116 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale, // various shapes
|
||||
std::optional<std::tuple<int64_t, int64_t>>
|
||||
opt_group_shape) // optional explicit (group_m, group_n)
|
||||
{
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int hidden_size = input.size(-1); // N (columns)
|
||||
const int num_tokens = input.numel() / hidden_size; // M (rows)
|
||||
|
||||
// Determine group_m, group_n, and scale strides from scale shape
|
||||
// Scale indexing: scale[gi * scale_stride_j + gj * scale_stride_i]
|
||||
// where gi = m / group_m, gj = n / group_n
|
||||
int group_m, group_n;
|
||||
int64_t scale_stride_i, scale_stride_j;
|
||||
|
||||
if (scale.dim() == 0 || scale.numel() == 1) {
|
||||
// Per-tensor: one scale for the entire tensor
|
||||
group_m = num_tokens;
|
||||
group_n = hidden_size;
|
||||
scale_stride_i = 0;
|
||||
scale_stride_j = 0;
|
||||
} else if (scale.dim() == 1) {
|
||||
// 1D scale: require explicit group_shape to disambiguate per-channel vs
|
||||
// per-token (avoids edge case where num_tokens == hidden_size)
|
||||
TORCH_CHECK(opt_group_shape.has_value(),
|
||||
"1D scale requires explicit group_shape to disambiguate "
|
||||
"per-channel vs per-token quantization. "
|
||||
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
|
||||
"-1) for per-token.");
|
||||
|
||||
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
|
||||
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
|
||||
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
|
||||
|
||||
// Validate the explicit group shape matches the 1D scale
|
||||
const int64_t scale_len = scale.numel();
|
||||
const int64_t expected_scale_m = num_tokens / group_m;
|
||||
const int64_t expected_scale_n = hidden_size / group_n;
|
||||
const int64_t expected_scale_numel = expected_scale_m * expected_scale_n;
|
||||
|
||||
TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (",
|
||||
scale_len, ") does not match expected size (",
|
||||
expected_scale_numel, ") for group_shape (", opt_group_m, ", ",
|
||||
opt_group_n, ") with input shape (", num_tokens, ", ",
|
||||
hidden_size, ")");
|
||||
|
||||
// For 1D scale, determine strides based on which dim is trivial
|
||||
// Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j]
|
||||
// where gi = m / group_m (row group), gj = n / group_n (col group)
|
||||
if (expected_scale_m == 1) {
|
||||
// Per-channel style: one scale in M dim, scale varies along N
|
||||
// gi = 0 always, gj varies, so stride_1 traverses the scale
|
||||
scale_stride_i = 0;
|
||||
scale_stride_j = scale.stride(0);
|
||||
} else if (expected_scale_n == 1) {
|
||||
// Per-token style: one scale in N dim, scale varies along M
|
||||
// gj = 0 always, gi varies, so stride_0 traverses the scale
|
||||
scale_stride_i = scale.stride(0);
|
||||
scale_stride_j = 0;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"1D scale can only be used when one of the scale dimensions is 1. "
|
||||
"For 2D group scaling, use a 2D scale tensor.");
|
||||
}
|
||||
} else if (scale.dim() == 2) {
|
||||
// 2D scale: infer group sizes from scale dimensions (or use explicit if
|
||||
// provided)
|
||||
const int64_t scale_size_0 = scale.size(0);
|
||||
const int64_t scale_size_1 = scale.size(1);
|
||||
|
||||
TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens,
|
||||
") must be divisible by scale.size(0) (", scale_size_0, ")");
|
||||
TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", hidden_size,
|
||||
") must be divisible by scale.size(1) (", scale_size_1, ")");
|
||||
|
||||
// Infer from 2D scale shape
|
||||
int inferred_group_m = num_tokens / scale_size_0;
|
||||
int inferred_group_n = hidden_size / scale_size_1;
|
||||
|
||||
// Use explicit if provided, otherwise use inferred
|
||||
if (opt_group_shape.has_value()) {
|
||||
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
|
||||
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
|
||||
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
|
||||
|
||||
// Validate explicit matches inferred
|
||||
TORCH_CHECK(group_m == inferred_group_m && group_n == inferred_group_n,
|
||||
"Explicit group_shape (", opt_group_m, ", ", opt_group_n,
|
||||
") does not match inferred group shape (", inferred_group_m,
|
||||
", ", inferred_group_n, ") from 2D scale tensor shape (",
|
||||
scale_size_0, ", ", scale_size_1, ")");
|
||||
} else {
|
||||
group_m = inferred_group_m;
|
||||
group_n = inferred_group_n;
|
||||
}
|
||||
|
||||
scale_stride_i = scale.stride(0);
|
||||
scale_stride_j = scale.stride(1);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ",
|
||||
scale.dim(), "D");
|
||||
}
|
||||
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(block_size);
|
||||
@@ -153,15 +301,23 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Dispatch to template-specialized kernel based on stride pattern
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
VLLM_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] {
|
||||
VLLM_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] {
|
||||
vllm::scaled_fp8_quant_kernel_strided_group_shape<
|
||||
scalar_t, fp8_t, S0_ZERO, S1_ZERO>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride, group_m, group_n, scale_stride_i,
|
||||
scale_stride_j);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -599,9 +599,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
||||
|
||||
// Compute FP8 quantized tensor for given scaling factor.
|
||||
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
|
||||
// scaling. Optional group_m/group_n specify the group shape explicitly;
|
||||
// required for 1D scales to disambiguate per-channel vs per-token.
|
||||
ops.def(
|
||||
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
|
||||
"()");
|
||||
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
|
||||
"(int, int)? group_shape=None) -> ()");
|
||||
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
|
||||
|
||||
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
|
||||
|
||||
@@ -11,6 +11,10 @@ from tests.kernels.quant_utils import (
|
||||
ref_dynamic_per_token_quant,
|
||||
)
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_quantize,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
@@ -21,10 +25,18 @@ SEEDS = [0]
|
||||
|
||||
|
||||
def opcheck_fp8_quant(
|
||||
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
|
||||
output,
|
||||
input,
|
||||
scale=None,
|
||||
scale_ub=None,
|
||||
use_per_token_if_dynamic=False,
|
||||
group_shape=None,
|
||||
):
|
||||
if scale is not None:
|
||||
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
|
||||
opcheck(
|
||||
torch.ops._C.static_scaled_fp8_quant,
|
||||
(output, input, scale, group_shape),
|
||||
)
|
||||
elif use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(input.shape[0], 1), device=input.device, dtype=torch.float32
|
||||
@@ -118,3 +130,92 @@ def test_fp8_quant_large(seed: int) -> None:
|
||||
ops_out = ops_out.to(dtype=dtype)
|
||||
|
||||
torch.testing.assert_close(ref_out, ops_out)
|
||||
|
||||
|
||||
# Test static FP8 quantization with 2D group scales
|
||||
GROUP_SHAPES_2D = [
|
||||
(-1, -1), # Per-tensor
|
||||
(-1, 1), # Per-channel
|
||||
(1, -1), # Per-token
|
||||
(-1, 128), # Per-head quantization
|
||||
(1, 128), # DeepSeek-style per-token-per-group (group_m=1, group_n=128)
|
||||
(128, 128), # DeepSeek-style block quantization
|
||||
(1, 64), # Smaller group size
|
||||
(1, 16), # Small group (scalar path in kernel)
|
||||
(4, 256), # Non-trivial both dimensions
|
||||
]
|
||||
# Use sizes divisible by all group shapes
|
||||
NUM_TOKENS_GROUP = [128, 512]
|
||||
HIDDEN_SIZES_GROUP = [256, 1024, 2048]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
|
||||
@pytest.mark.parametrize("group_shape", GROUP_SHAPES_2D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_static_fp8_quant_group_2d(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
group_shape: tuple[int, int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
"""Test static FP8 quantization with 2D group scales using scaled_quantize."""
|
||||
# Normalize group_shape (-1 means full extent)
|
||||
norm_group_m = num_tokens if group_shape[0] == -1 else group_shape[0]
|
||||
norm_group_n = hidden_size if group_shape[1] == -1 else group_shape[1]
|
||||
|
||||
# Skip if sizes are not divisible by group shape
|
||||
if num_tokens % norm_group_m != 0 or hidden_size % norm_group_n != 0:
|
||||
pytest.skip(
|
||||
f"Skipping: ({num_tokens}, {hidden_size}) not divisible by "
|
||||
f"group_shape ({group_shape[0]}, {group_shape[1]})"
|
||||
)
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
ref_out, scale = scaled_quantize(
|
||||
x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
|
||||
)
|
||||
ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape)
|
||||
|
||||
torch.testing.assert_close(scale, ops_scale)
|
||||
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
|
||||
|
||||
opcheck_fp8_quant(ops_out, x, scale=scale)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("group_shape", [(1, -1), (-1, 1)]) # per-token, per-channel
|
||||
@torch.inference_mode()
|
||||
def test_static_fp8_quant_1d_scale(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
group_shape: tuple[int, int],
|
||||
) -> None:
|
||||
"""Test static FP8 quantization with 1D scale (per-token or per-channel)."""
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
ref_out, scale_2d = scaled_quantize(
|
||||
x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
|
||||
)
|
||||
|
||||
# Flatten scale to 1D for testing 1D scale path
|
||||
scale_1d = scale_2d.flatten()
|
||||
ops_out, ops_scale = ops.scaled_fp8_quant(
|
||||
x, scale=scale_1d, group_shape=group_shape
|
||||
)
|
||||
|
||||
torch.testing.assert_close(scale_1d, ops_scale)
|
||||
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
|
||||
|
||||
opcheck_fp8_quant(ops_out, x, scale=scale_1d, group_shape=group_shape)
|
||||
|
||||
@@ -1752,6 +1752,7 @@ def scaled_fp8_quant(
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
output: torch.Tensor | None = None,
|
||||
group_shape: tuple[int, int] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 and return quantized tensor and scale.
|
||||
@@ -1763,14 +1764,23 @@ def scaled_fp8_quant(
|
||||
will benefit from padding.
|
||||
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP8
|
||||
scale: Optional scaling factor for the FP8 quantization
|
||||
input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
|
||||
scale: Optional scaling factor for the FP8 quantization. Supports:
|
||||
- 0D or [1]: per-tensor scaling
|
||||
- 1D: requires explicit group_shape to disambiguate per-channel
|
||||
vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
|
||||
- 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
|
||||
DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
|
||||
scale_ub: Optional upper bound for scaling factor in dynamic
|
||||
per token case
|
||||
num_token_padding: If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
||||
in the dynamic quantization case.
|
||||
group_shape: Optional tuple (group_m, group_n) specifying the group
|
||||
shape for static quantization. Use -1 for "full extent" (e.g.,
|
||||
(-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
|
||||
Required for 1D scales; optional for 2D scales.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
||||
@@ -1799,8 +1809,7 @@ def scaled_fp8_quant(
|
||||
scale = torch.empty(1, device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
assert scale.numel() == 1, f"{scale.shape}"
|
||||
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
||||
torch.ops._C.static_scaled_fp8_quant(output, input, scale, group_shape)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
get_fp8_min_max,
|
||||
group_broadcast,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -22,7 +23,7 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
|
||||
@CustomOp.register("quant_fp8")
|
||||
class QuantFP8(CustomOp):
|
||||
"""
|
||||
Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
|
||||
Quantize input tensor to FP8 (per-tensor, per-token, per-channel, or per-group).
|
||||
This CustomOp supports both static and dynamic quantization.
|
||||
"""
|
||||
|
||||
@@ -57,14 +58,14 @@ class QuantFP8(CustomOp):
|
||||
|
||||
self.is_group_quant = group_shape.is_per_group()
|
||||
if self.is_group_quant:
|
||||
assert not static, "Group quantization only supports dynamic mode"
|
||||
self.group_size = group_shape.col
|
||||
else:
|
||||
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
|
||||
assert not static or group_shape == GroupShape.PER_TENSOR, (
|
||||
"Only per-tensor scales supported for static quantization."
|
||||
)
|
||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||
if not static:
|
||||
assert group_shape in (GroupShape.PER_TOKEN, GroupShape.PER_TENSOR), (
|
||||
"Only per-token or per-tensor scales are supported for dynamic "
|
||||
"non-group quantization."
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@@ -72,8 +73,8 @@ class QuantFP8(CustomOp):
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.is_group_quant:
|
||||
assert scale is None, "Group quantization is always dynamic"
|
||||
if self.is_group_quant and not self.static:
|
||||
assert scale is None, "Dynamic group quantization does not use scale"
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils
|
||||
|
||||
return fp8_utils.per_token_group_quant_fp8(
|
||||
@@ -90,12 +91,14 @@ class QuantFP8(CustomOp):
|
||||
and self.group_shape == GroupShape.PER_TOKEN
|
||||
and scale_ub.numel() == 1
|
||||
)
|
||||
|
||||
return ops.scaled_fp8_quant(
|
||||
x,
|
||||
scale,
|
||||
num_token_padding=self.num_token_padding,
|
||||
scale_ub=scale_ub,
|
||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||
group_shape=self.group_shape if self.static else None,
|
||||
)
|
||||
|
||||
def forward_hip(
|
||||
@@ -131,8 +134,8 @@ class QuantFP8(CustomOp):
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
):
|
||||
if self.is_group_quant:
|
||||
assert scale is None, "Group quantization is always dynamic"
|
||||
if self.is_group_quant and not self.static:
|
||||
assert scale is None, "Dynamic group quantization does not use scale"
|
||||
return self._quantize_group_native(x)
|
||||
|
||||
assert (scale is not None) == self.static
|
||||
@@ -155,7 +158,10 @@ class QuantFP8(CustomOp):
|
||||
|
||||
# Even for dynamic per-token scales,
|
||||
# reciprocal performs slightly better than division
|
||||
out = x.to(torch.float32) * scale.reciprocal()
|
||||
out = (
|
||||
x.to(torch.float32)
|
||||
* group_broadcast(scale.to(torch.float32), x.shape[-2:]).reciprocal()
|
||||
)
|
||||
out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
# This currently generates an extra Triton kernel in compilation.
|
||||
|
||||
@@ -158,11 +158,14 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# with an extent of 1, since this can be done implicitly by pytorch
|
||||
def group_broadcast(t, shape):
|
||||
for i, s in enumerate(shape):
|
||||
if t.shape[i] != s and t.shape[i] != 1:
|
||||
assert s % t.shape[i] == 0
|
||||
# If tensor has fewer dimensions than target shape, treat missing
|
||||
# dimensions as size 1 (standard PyTorch broadcasting behavior)
|
||||
t_dim_size = t.shape[i] if i < t.ndim else 1
|
||||
if t_dim_size != s and t_dim_size != 1:
|
||||
assert s % t_dim_size == 0
|
||||
t = (
|
||||
t.unsqueeze(i + 1)
|
||||
.expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
|
||||
.expand(*t.shape[: i + 1], s // t_dim_size, *t.shape[i + 1 :])
|
||||
.flatten(i, i + 1)
|
||||
)
|
||||
return t
|
||||
@@ -180,7 +183,16 @@ def scaled_quantize(
|
||||
x: torch.Tensor,
|
||||
group_shape: GroupShape,
|
||||
quant_dtype: torch.dtype,
|
||||
compute_dtype: torch.dtype | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x: Input tensor to quantize
|
||||
group_shape: Shape of quantization groups
|
||||
quant_dtype: Target quantized dtype (e.g., torch.float8_e4m3fn)
|
||||
compute_dtype: Optional dtype for intermediate computations.
|
||||
If None, uses input dtype. Use torch.float32 for higher precision.
|
||||
"""
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
assert quant_dtype.is_floating_point, (
|
||||
"currently `scaled_quantize` only supports floating point dtypes "
|
||||
@@ -189,11 +201,14 @@ def scaled_quantize(
|
||||
|
||||
finfo = torch.finfo(quant_dtype)
|
||||
|
||||
# Convert to compute dtype if specified
|
||||
x_compute = x if compute_dtype is None else x.to(compute_dtype)
|
||||
|
||||
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
|
||||
assert x.ndim == 2
|
||||
assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
|
||||
blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
|
||||
x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
|
||||
x_blkd = x_compute.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
|
||||
|
||||
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
|
||||
|
||||
Reference in New Issue
Block a user