[NVIDIA] Support SiluMul + NVFP4 quant fusion (#23671)

Signed-off-by: jindih <jindih@nvidia.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: jindih <jindih@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedic <lgovedic@redhat.com>
Authored by elvischenv on 2025-08-29 03:36:50 +08:00, committed by GitHub
parent 57d4ede520 · commit 16a45b3a28
11 changed files with 746 additions and 64 deletions
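For context on how this pass is exercised: the activation-quant fusion pass runs among vLLM's torch.compile post-grad passes and is gated by the compilation config. A minimal usage sketch follows; the model name is a placeholder for an NVFP4-quantized checkpoint, and the exact PassConfig keys should be checked against the vLLM docs for your version.

# Sketch (assumptions hedged): enable the custom fusion passes, which
# include ActivationQuantFusionPass, when compiling the model.
from vllm import LLM

llm = LLM(
    model="nvidia/model-fp4",  # placeholder for an NVFP4 checkpoint
    compilation_config={"pass_config": {"enable_fusion": True}},
)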


@@ -1,55 +1,154 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+
 import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                              register_replacement)
+from torch._ops import OpOverload

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
 from vllm.platforms import current_platform

+from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass

 logger = init_logger(__name__)

+FP8_DTYPE = current_platform.fp8_dtype()
+FP4_DTYPE = torch.uint8
-def silu_mul_pattern_static(result: torch.Tensor,
-                            result_silu_mul: torch.Tensor, input: torch.Tensor,
-                            scale: torch.Tensor):
-    at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
-                              result=result_silu_mul,
-                              input=input)
-    at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
-                              result=result,
-                              input=at1[1],
-                              scale=scale)
-    return at2[1]
+SILU_MUL_OP = torch.ops._C.silu_and_mul.default
+
+FUSED_OPS: dict[QuantKey, OpOverload] = {
+    kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default,  # noqa: E501
+}
+if current_platform.is_cuda() and hasattr(torch.ops._C,
+                                          "silu_and_mul_nvfp4_quant"):
+    FUSED_OPS[
+        kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501
-def silu_mul_replacement_static(result: torch.Tensor,
-                                result_silu_mul: torch.Tensor,
-                                input: torch.Tensor, scale: torch.Tensor):
-    at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
-                             result=result,
-                             input=input,
-                             scale=scale)
-    return at[1]
+class ActivationQuantPattern(ABC):
+    """
+    The base class for Activation+Quant fusions.
+    Should not be used directly.
+    """
+
+    def __init__(
+        self,
+        quant_key: QuantKey,
+    ):
+        self.quant_key = quant_key
+        self.quant_dtype = quant_key.dtype
+
+        assert self.quant_key in QUANT_OPS, \
+            f"unsupported quantization scheme {self.quant_key}"
+        self.QUANT_OP = QUANT_OPS[self.quant_key]
+
+        assert self.quant_key in FUSED_OPS, \
+            f"unsupported fusion scheme {self.quant_key}"
+        self.FUSED_OP = FUSED_OPS[self.quant_key]
+
+    def empty_quant(self, *args, **kwargs):
+        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
+        return torch.empty(*args, **kwargs)
+
+    @abstractmethod
+    def register(self, pm_pass: PatternMatcherPass):
+        raise NotImplementedError
-def empty_bf16(*args, **kwargs):
-    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
+class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
+    """
+    Fusion for SiluMul+Fp8StaticQuant Pattern
+    """
+
+    def __init__(self, symmetric: bool = True):
+        quant_key = QuantKey(dtype=FP8_DTYPE,
+                             scale=kStaticTensorScale,
+                             symmetric=symmetric)
+        super().__init__(quant_key)
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
+                    input: torch.Tensor, scale: torch.Tensor):
+            at1 = auto_functionalized(SILU_MUL_OP,
+                                      result=result_silu_mul,
+                                      input=input)
+            at2 = auto_functionalized(self.QUANT_OP,
+                                      result=result,
+                                      input=at1[1],
+                                      scale=scale)
+            return at2[1]
+
+        def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
+                        input: torch.Tensor, scale: torch.Tensor):
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     input=input,
+                                     scale=scale)
+            return at[1]
+
+        inputs = [
+            self.empty_quant(5, 4),  # result
+            empty_bf16(5, 4),  # result_silu_mul
+            empty_bf16(5, 4),  # input
+            empty_fp32(1, 1)  # scale
+        ]
+        register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
-def empty_fp8(*args, **kwargs):
-    fp8 = current_platform.fp8_dtype()
-    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
+class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
+    """
+    Fusion for SiluMul+Nvfp4Quant Pattern
+    """
+
+    def __init__(self):
+        super().__init__(kNvfp4Quant)
+
-def empty_fp32(*args, **kwargs):
-    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(result: torch.Tensor, output_scale: torch.Tensor,
+                    result_silu_mul: torch.Tensor, input: torch.Tensor,
+                    scale: torch.Tensor):
+            at1 = auto_functionalized(SILU_MUL_OP,
+                                      result=result_silu_mul,
+                                      input=input)
+            at2 = auto_functionalized(self.QUANT_OP,
+                                      output=result,
+                                      input=at1[1],
+                                      output_scale=output_scale,
+                                      input_scale=scale)
+            return at2[1], at2[2]
+
+        def replacement(result: torch.Tensor, output_scale: torch.Tensor,
+                        result_silu_mul: torch.Tensor, input: torch.Tensor,
+                        scale: torch.Tensor):
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     result_block_scale=output_scale,
+                                     input=input,
+                                     input_global_scale=scale)
+            return at[1], at[2]
+
+        inputs = [
+            self.empty_quant(5, 32),  # result
+            empty_i32(128, 4),  # output_scale
+            empty_bf16(5, 64),  # result_silu_mul
+            empty_bf16(5, 64),  # input
+            empty_fp32(1, 1)  # scale
+        ]
+        register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)

 class ActivationQuantFusionPass(VllmInductorPass):
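For intuition, the FP8 pattern above matches exactly this unfused computation. The reference sketch below uses helper names of my own, not code from the diff: silu_and_mul applies SiLU to the first half of the last dimension and multiplies it by the second half, and the static quant divides by a fixed scale, clamps, and casts to FP8.

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # SiLU(x[..., :d]) * x[..., d:], halving the last dimension
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def static_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # per-tensor static quantization: divide by scale, clamp, cast
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

x = torch.randn(5, 8, dtype=torch.bfloat16)
scale = torch.tensor(0.02)
# Two kernels, with the intermediate round-tripping through memory; the
# pass rewrites this pair into a single silu_and_mul_quant call with the
# same semantics.
out = static_fp8_quant_ref(silu_and_mul_ref(x), scale)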
@@ -69,15 +168,11 @@ class ActivationQuantFusionPass(VllmInductorPass):
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="activation_quant_fusion_pass")

-        inputs = [
-            empty_fp8(5, 4),  # Quant output
-            empty_bf16(5, 4),  # Silu_and_mul output
-            empty_bf16(5, 4),  # Input
-            empty_fp32(1, 1)  # Scale
-        ]
-
-        register_replacement(silu_mul_pattern_static,
-                             silu_mul_replacement_static, inputs, fwd_only,
-                             self.patterns)
+        pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
+        pattern_silu_mul_fp8.register(self.patterns)
+
+        pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
+        pattern_silu_mul_nvfp4.register(self.patterns)

     def __call__(self, graph: torch.fx.Graph):
         self.begin()
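The register calls above go through torch._inductor's pattern matcher. For readers unfamiliar with that API, here is a toy, vLLM-independent example of the same register_replacement mechanism (names are mine):

import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

toy_patterns = PatternMatcherPass(pass_name="toy_pass")

def toy_pattern(x: torch.Tensor):
    return torch.relu(x) + 1

def toy_replacement(x: torch.Tensor):
    return torch.clamp(x, min=0) + 1

# The dummy inputs are only used to trace the pattern into a graph.
register_replacement(toy_pattern, toy_replacement, [torch.empty(5, 4)],
                     fwd_only, toy_patterns)
# Later, toy_patterns.apply(graph) rewrites any matches in an FX graph.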
@@ -89,3 +184,8 @@ class ActivationQuantFusionPass(VllmInductorPass):
         self.dump_graph(graph, "after_act_quant_fusion")
         self.end_and_log()
+
+    def uuid(self):
+        return VllmInductorPass.hash_source(self, ActivationQuantPattern,
+                                            SiluMulFp8StaticQuantPattern,
+                                            SiluMulNvfp4QuantPattern)
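A note on the NVFP4 dummy shapes in the pattern registration: NVFP4 stores FP4 (e2m1) values in blocks of 16, each block carrying an FP8 scale, with one FP32 global scale for the tensor. The arithmetic below is illustrative bookkeeping under those assumptions, not kernel code:

hidden = 64                          # last-dim width of the silu_and_mul output
packed_bytes = hidden // 2           # two FP4 values per uint8 byte -> 32
block_scales_per_row = hidden // 16  # one FP8 block scale per 16 values -> 4

assert packed_bytes == 32            # matches self.empty_quant(5, 32)
assert block_scales_per_row == 4     # matches the trailing dim of empty_i32(128, 4)
# The leading 128 in empty_i32(128, 4) reflects the padded, swizzled
# block-scale layout the kernel expects, not the token count.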