[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
elvischenv authored 2025-08-23 06:09:05 +08:00 · committed by GitHub
parent cc7ae5e7ca · commit 24d0c9e6ed
27 changed files with 596 additions and 200 deletions


@@ -2,16 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import ClassVar, NamedTuple, Optional

import numpy
import torch
from torch import fx

from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8


# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
@@ -34,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
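The GroupShape class definition itself sits outside this hunk; as a rough reconstruction of the proxy pattern the comment above refers to (a sketch only, assuming the static members are declared as ClassVars on a thin wrapper class), it would look roughly like:

class GroupShape(_GroupShape):
    # ClassVar members cannot be declared directly on a NamedTuple subclass,
    # so they are attached to this wrapper and assigned after the class body.
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']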
@dataclass(frozen=True)
class ScaleDesc:
    """
    Class for describing a single quantization scaling factor.
    dtype: data type of the scale
    static: static scale if True, dynamic if False
    group_shape: group shape of the scale
    """
    dtype: torch.dtype
    static: bool
    group_shape: GroupShape

    def __str__(self):
        group_shape = ('per_tensor'
                       if self.group_shape == GroupShape.PER_TENSOR else
                       ('per_token' if self.group_shape == GroupShape.PER_TOKEN
                        else str(self.group_shape)))
        return (f"{fx.graph.dtype_abbrs[self.dtype]},"
                f"{'static' if self.static else 'dynamic'},{group_shape}")
@dataclass(frozen=True)
class QuantKey:
    """
    Class for identifying the type of quantization.
    dtype: quantized data type
    scale: scale descriptor
    scale2: second-level scale descriptor
    symmetric: symmetric if True, asymmetric if False
    """
    dtype: torch.dtype
    scale: ScaleDesc
    scale2: Optional[ScaleDesc] = None
    symmetric: bool = True

    def __str__(self):
        scale2_str = f"scale2({self.scale2})," if self.scale2 else ""
        return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]},"
                f"scale({self.scale}),{scale2_str}"
                f"{'a' if not self.symmetric else ''}symmetric)")
kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR)
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)

kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)

kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)

kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE,
                       scale=kNvfp4GroupScale,
                       scale2=kStaticTensorScale)
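To spell out what the NVFP4 key encodes (an illustrative reading, not part of the diff): the quantized payload is stored as packed torch.uint8 (two FP4 values per byte), the first-level scales are dynamic FP8 values computed over 1x16 groups, and scale2 is a static per-tensor float32 global scale.

assert kNvfp4Quant.dtype is FP4_DTYPE                    # packed uint8 payload
assert kNvfp4Quant.scale == ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
assert kNvfp4Quant.scale2 == kStaticTensorScale          # float32, static, per-tensor
assert kNvfp4Quant.symmetric                             # defaults to True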
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
    # -1 means full extent