[torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations (#10867)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore
2025-05-01 07:59:28 -07:00
committed by GitHub
parent 28566d73b3
commit 460a2b1100
11 changed files with 406 additions and 9 deletions

View File

@@ -5,6 +5,7 @@ import torch
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
@@ -17,7 +18,6 @@ from .backend import TestBackend
OPS_IN_MODEL = [
torch.ops._C.rotary_embedding.default,
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.silu_and_mul.default,
]
RMS_OP = torch.ops._C.rms_norm.default
@@ -29,6 +29,9 @@ RMS_QUANT_OPS = {
],
}
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -55,8 +58,10 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
passes = [noop_pass, fusion_pass, act_quant_fusion_pass
] if do_fusion else [noop_pass]
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
@@ -79,6 +84,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_no_func)
gen_no_func = llm.generate(prompts, sampling_params)
for output_func, output_no_func in zip(gen_func, gen_no_func):
@@ -88,7 +94,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
] if do_fusion else [RMS_OP]
ops = OPS_IN_MODEL + rms_ops
silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
quant_key == kFp8StaticTensorSym else [
SILU_MUL_OP
]
ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)