[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830)

Signed-off-by: Luka Govedic <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Luka Govedič
2025-07-11 00:56:28 -04:00
committed by GitHub
parent 35514b682a
commit 31d5c1797f
18 changed files with 368 additions and 104 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, ClassVar, NamedTuple, Optional
from typing import Callable, NamedTuple, Optional
import torch
import torch._inductor.pattern_matcher as pm
@@ -11,6 +11,8 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
@@ -33,27 +35,6 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Use proxy as NamedTuple direct subclasses cannot have static members
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
    row: int
    col: int


class GroupShape(_GroupShape):
    """Quantization group shape as a (row, col) pair.

    Besides arbitrary shapes, static aliases are provided for the two
    common cases: per-tensor and per-token quantization.
    """

    # Aliases for common quantization group shapes
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']


GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.