[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
from typing import Callable, ClassVar, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -34,36 +33,66 @@ RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
row: int
|
||||
col: int
|
||||
|
||||
|
||||
class GroupShape(_GroupShape):
|
||||
"""
|
||||
This class describes the quantization group shape.
|
||||
It includes static members for common shapes (per-tensor, per-token).
|
||||
"""
|
||||
|
||||
# Aliases for common quantization group shapes
|
||||
PER_TENSOR: ClassVar['GroupShape']
|
||||
PER_TOKEN: ClassVar['GroupShape']
|
||||
|
||||
|
||||
GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
class QuantKey(NamedTuple):
|
||||
"""
|
||||
Named tuple for identifying the type of quantization.
|
||||
dtype: quantized data type
|
||||
static: static quantization if True, dynamic if False
|
||||
per_tensor: per-tensor quantization if True, per-token if False
|
||||
group_shape: quantization group shape
|
||||
symmetric: symmetric if True, asymmetric if False
|
||||
|
||||
TODO(luka) use QuantDescriptor once standardized:
|
||||
https://github.com/vllm-project/vllm/issues/8913
|
||||
|
||||
"""
|
||||
dtype: torch.dtype
|
||||
static: bool
|
||||
per_tensor: bool = True
|
||||
group_shape: GroupShape
|
||||
symmetric: bool = True
|
||||
|
||||
def __str__(self):
|
||||
group_shape = ('per_tensor'
|
||||
if self.group_shape == GroupShape.PER_TENSOR else
|
||||
('per_token' if self.group_shape == GroupShape.PER_TOKEN
|
||||
else str(self.group_shape)))
|
||||
|
||||
return (f"QuantKey({'static' if self.static else 'dynamic'},"
|
||||
f"{fx.graph.dtype_abbrs[self.dtype]},"
|
||||
f"{'per_tensor' if self.per_tensor else 'per_token'},"
|
||||
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
|
||||
f"{'a' if not self.symmetric else ''}symmetric)")
|
||||
|
||||
|
||||
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
|
||||
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
|
||||
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
|
||||
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
|
||||
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
|
||||
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa
|
||||
kFp8StaticTensorSym:
|
||||
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym:
|
||||
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa
|
||||
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym:
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
@@ -83,13 +112,13 @@ class FusedRMSQuantKey(NamedTuple):
|
||||
|
||||
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
|
||||
FusedRMSQuantKey(kFp8StaticTensorSym, False):
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8StaticTensorSym, True):
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
@@ -177,10 +206,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
fused_key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
static=True,
|
||||
per_tensor=True,
|
||||
symmetric=symmetric))
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype,
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, fused_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
@@ -233,10 +263,11 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
static=True,
|
||||
per_tensor=True,
|
||||
symmetric=symmetric))
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype,
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
@@ -323,12 +354,12 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
per_tensor: bool,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
static=False,
|
||||
per_tensor=per_tensor,
|
||||
group_shape=group_shape,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
@@ -421,12 +452,12 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
per_tensor: bool = True,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
static=False,
|
||||
per_tensor=per_tensor,
|
||||
group_shape=group_shape,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
@@ -566,16 +597,12 @@ class FusionPass(VllmInductorPass):
|
||||
self.patterns, self.record_match)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
|
||||
per_tensor=False).register(
|
||||
self.patterns, self.record_match)
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns, self.record_match)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon,
|
||||
FP8_DTYPE,
|
||||
per_tensor=False).register(
|
||||
self.patterns,
|
||||
self.record_match)
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns, self.record_match)
|
||||
|
||||
# WARNING: This is a hack to clear the pattern matcher cache
|
||||
# and allow multiple values of epsilon.
|
||||
|
||||
Reference in New Issue
Block a user