[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830)

Signed-off-by: Luka Govedic <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Luka Govedič authored on 2025-07-11 00:56:28 -04:00, committed by GitHub
commit 31d5c1797f (parent 35514b682a)
18 changed files with 368 additions and 104 deletions


@@ -3,7 +3,7 @@
 """This file is used for /tests and /benchmarks"""
 from collections.abc import Mapping
 from types import MappingProxyType
-from typing import Optional
+from typing import ClassVar, NamedTuple, Optional
 
 import numpy
 import torch
@@ -12,13 +12,30 @@ from vllm.model_executor.layers.quantization.qqq import (
     MARLIN_QQQ_SUPPORTED_NUM_BITS)
 from vllm.scalar_type import ScalarType, scalar_types
 
-SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
-SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+# Use proxy as NamedTuple direct subclasses cannot have static members
+class _GroupShape(NamedTuple):
+    row: int
+    col: int
+
+
+class GroupShape(_GroupShape):
+    """
+    This class describes the quantization group shape.
+    It includes static members for common shapes (per-tensor, per-token).
+    """
+    # Aliases for common quantization group shapes
+    PER_TENSOR: ClassVar['GroupShape']
+    PER_TOKEN: ClassVar['GroupShape']
+
+
+GroupShape.PER_TENSOR = GroupShape(-1, -1)
+GroupShape.PER_TOKEN = GroupShape(1, -1)
+
 
 
 # Normalize the group_shape to the full extent for any dims that are -1
-def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
-                                                                     int]):
+def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
     # -1 means full extent
     return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
             group_shape[1] if group_shape[1] > 0 else x.shape[-1])
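
Note: a minimal standalone sketch of the new semantics, re-implemented here for illustration. Only the `GroupShape` definitions and the normalization rule come from this diff; `_normalize` is a hypothetical stand-in for `_normalize_quant_group_shape`:

```python
from typing import ClassVar, NamedTuple

import torch


class _GroupShape(NamedTuple):
    row: int
    col: int


class GroupShape(_GroupShape):
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']


GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)


def _normalize(x: torch.Tensor, g: GroupShape) -> tuple[int, int]:
    # -1 expands to the full extent of the corresponding dimension
    return (g.row if g.row > 0 else x.shape[-2],
            g.col if g.col > 0 else x.shape[-1])


x = torch.randn(16, 32)
assert _normalize(x, GroupShape.PER_TENSOR) == (16, 32)  # one scale overall
assert _normalize(x, GroupShape.PER_TOKEN) == (1, 32)    # one scale per row
assert GroupShape.PER_TOKEN == (1, -1)  # still compares equal to a plain tuple
```

Because `GroupShape` is ultimately a `NamedTuple`, existing call sites that pass bare `(row, col)` tuples keep working; the class mainly adds named fields and the `PER_TENSOR`/`PER_TOKEN` aliases.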
@@ -58,7 +75,7 @@ def group_broadcast(t, shape):
 # (i.e. per-token-per-group)
 def scaled_quantize(
     x: torch.Tensor,
-    group_shape: tuple[int, int],
+    group_shape: GroupShape,
     quant_dtype: torch.dtype,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     group_shape = _normalize_quant_group_shape(x, group_shape)
@@ -99,7 +116,7 @@ def scaled_quantize(
 def scaled_dequantize(
     x_q: torch.Tensor,
     x_s: torch.Tensor,
-    group_shape: Optional[tuple[int, int]] = None,
+    group_shape: Optional[GroupShape] = None,
     out_dtype: torch.dtype = torch.float32,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     if group_shape is not None:
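
Note: a hedged usage sketch of the updated signatures, assuming `scaled_quantize`, `scaled_dequantize`, and `GroupShape` from this file are in scope; the fp8 dtype choice is illustrative, not prescribed by this diff:

```python
import torch

# Assumes the helpers shown above are importable from this file's module.
x = torch.randn(16, 32, dtype=torch.float32)

# Quantize with one scale per token (row), i.e. group shape (1, -1).
x_q, x_s = scaled_quantize(x, GroupShape.PER_TOKEN,
                           quant_dtype=torch.float8_e4m3fn)

# Dequantize with the same group shape.
x_dq = scaled_dequantize(x_q, x_s,
                         group_shape=GroupShape.PER_TOKEN,
                         out_dtype=torch.float32)
```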
@@ -332,6 +349,10 @@ def quantize_weights(w: torch.Tensor,
     )
 
+
+SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
 def gptq_quantize_weights(w: torch.Tensor,
                           quant_type: ScalarType,
                           group_size: int,