[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -12,7 +12,8 @@ from torch._ops import OpOverload
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import find_getitem_maybe
|
||||
@@ -21,6 +22,7 @@ from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*args, **kwargs):
|
||||
@@ -31,42 +33,13 @@ def empty_fp32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def empty_i32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
|
||||
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
|
||||
class QuantKey(NamedTuple):
|
||||
"""
|
||||
Named tuple for identifying the type of quantization.
|
||||
dtype: quantized data type
|
||||
static: static quantization if True, dynamic if False
|
||||
group_shape: quantization group shape
|
||||
symmetric: symmetric if True, asymmetric if False
|
||||
|
||||
TODO(luka) use QuantDescriptor once standardized:
|
||||
https://github.com/vllm-project/vllm/issues/8913
|
||||
|
||||
"""
|
||||
dtype: torch.dtype
|
||||
static: bool
|
||||
group_shape: GroupShape
|
||||
symmetric: bool = True
|
||||
|
||||
def __str__(self):
|
||||
group_shape = ('per_tensor'
|
||||
if self.group_shape == GroupShape.PER_TENSOR else
|
||||
('per_token' if self.group_shape == GroupShape.PER_TOKEN
|
||||
else str(self.group_shape)))
|
||||
|
||||
return (f"QuantKey({'static' if self.static else 'dynamic'},"
|
||||
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
|
||||
f"{'a' if not self.symmetric else ''}symmetric)")
|
||||
|
||||
|
||||
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
|
||||
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
|
||||
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym:
|
||||
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
@@ -74,6 +47,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym:
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
@@ -187,11 +161,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
fused_key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype,
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
symmetric=symmetric))
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, fused_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
@@ -244,11 +216,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype,
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
symmetric=symmetric))
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
@@ -337,10 +307,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
scale=scale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
@@ -435,10 +405,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
scale=scale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user