[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -2,16 +2,21 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
from typing import ClassVar, NamedTuple, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch import fx
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
# FP8 dtype as reported by the current platform; resolved once at import time.
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
# dtype used to represent FP4 data in this module (uint8 storage).
# NOTE(review): presumably two 4-bit values are packed per byte - confirm.
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
@@ -34,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ScaleDesc:
    """
    Descriptor for one quantization scaling factor.
    dtype: data type of the scale
    static: True for a static scale, False for a dynamic scale
    group_shape: group shape of the scale
    """
    dtype: torch.dtype
    static: bool
    group_shape: GroupShape

    def __str__(self):
        # The two well-known group shapes get a readable name; any other
        # shape falls back to its own string form.
        if self.group_shape == GroupShape.PER_TENSOR:
            shape_name = 'per_tensor'
        elif self.group_shape == GroupShape.PER_TOKEN:
            shape_name = 'per_token'
        else:
            shape_name = str(self.group_shape)

        dtype_name = fx.graph.dtype_abbrs[self.dtype]
        mode = 'static' if self.static else 'dynamic'
        return f"{dtype_name},{mode},{shape_name}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class QuantKey:
    """
    Key identifying a kind of quantization.
    dtype: quantized data type
    scale: scale descriptor
    scale2: second-level scale descriptor (None when absent)
    symmetric: True for symmetric, False for asymmetric
    """
    dtype: torch.dtype
    scale: ScaleDesc
    scale2: Optional[ScaleDesc] = None
    symmetric: bool = True

    def __str__(self):
        # The second-level scale only shows up in the string when it is set.
        second_level = ""
        if self.scale2:
            second_level = f"scale2({self.scale2}),"

        dtype_name = fx.graph.dtype_abbrs[self.dtype]
        first_level = f"scale({self.scale}),"
        symmetry = "symmetric" if self.symmetric else "asymmetric"
        return f"QuantKey({dtype_name},{first_level}{second_level}{symmetry})"
|
||||
|
||||
|
||||
# Static float32 scale covering the whole tensor, and its symmetric FP8 key.
kStaticTensorScale = ScaleDesc(dtype=torch.float32,
                               static=True,
                               group_shape=GroupShape.PER_TENSOR)
kFp8StaticTensorSym = QuantKey(dtype=FP8_DTYPE,
                               scale=kStaticTensorScale,
                               symmetric=True)

# Dynamic float32 scale covering the whole tensor, and its symmetric FP8 key.
kDynamicTensorScale = ScaleDesc(dtype=torch.float32,
                                static=False,
                                group_shape=GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(dtype=FP8_DTYPE,
                                scale=kDynamicTensorScale,
                                symmetric=True)

# Dynamic float32 per-token scale, and its symmetric FP8 key.
kDynamicTokenScale = ScaleDesc(dtype=torch.float32,
                               static=False,
                               group_shape=GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(dtype=FP8_DTYPE,
                               scale=kDynamicTokenScale,
                               symmetric=True)

# NVFP4: dynamic FP8 scales with group shape (1, 16) as the first level and
# a static per-tensor float32 scale as the second level.
kNvfp4GroupScale = ScaleDesc(dtype=FP8_DTYPE,
                             static=False,
                             group_shape=GroupShape(1, 16))
kNvfp4Quant = QuantKey(dtype=FP4_DTYPE,
                       scale=kNvfp4GroupScale,
                       scale2=kStaticTensorScale)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
|
||||
Reference in New Issue
Block a user