[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830)
Signed-off-by: Luka Govedic <lgovedic@redhat.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Optional
|
||||
from typing import ClassVar, NamedTuple, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -12,13 +12,30 @@ from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
row: int
|
||||
col: int
|
||||
|
||||
|
||||
class GroupShape(_GroupShape):
|
||||
"""
|
||||
This class describes the quantization group shape.
|
||||
It includes static members for common shapes (per-tensor, per-token).
|
||||
"""
|
||||
|
||||
# Aliases for common quantization group shapes
|
||||
PER_TENSOR: ClassVar['GroupShape']
|
||||
PER_TOKEN: ClassVar['GroupShape']
|
||||
|
||||
|
||||
GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
|
||||
int]):
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
|
||||
@@ -58,7 +75,7 @@ def group_broadcast(t, shape):
|
||||
# (i.e. per-token-per-group)
|
||||
def scaled_quantize(
|
||||
x: torch.Tensor,
|
||||
group_shape: tuple[int, int],
|
||||
group_shape: GroupShape,
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
@@ -99,7 +116,7 @@ def scaled_quantize(
|
||||
def scaled_dequantize(
|
||||
x_q: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
group_shape: Optional[tuple[int, int]] = None,
|
||||
group_shape: Optional[GroupShape] = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if group_shape is not None:
|
||||
@@ -332,6 +349,10 @@ def quantize_weights(w: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def gptq_quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
|
||||
Reference in New Issue
Block a user