[torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations (#10867)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
+ from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
@@ -17,7 +18,6 @@ from .backend import TestBackend
|
||||
OPS_IN_MODEL = [
|
||||
torch.ops._C.rotary_embedding.default,
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
- torch.ops._C.silu_and_mul.default,
|
||||
]
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
@@ -29,6 +29,9 @@ RMS_QUANT_OPS = {
|
||||
],
|
||||
}
|
||||
|
||||
+ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
+ SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@@ -55,8 +58,10 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
enable_noop=True))
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
+ act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
- passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
|
||||
+ passes = [noop_pass, fusion_pass, act_quant_fusion_pass
|
||||
+ ] if do_fusion else [noop_pass]
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
@@ -79,6 +84,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
model_runner.model = torch.compile(orig_model,
|
||||
fullgraph=True,
|
||||
backend=backend_no_func)
|
||||
|
||||
gen_no_func = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output_func, output_no_func in zip(gen_func, gen_no_func):
|
||||
@@ -88,7 +94,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
# and replaced by fused quantized ops in RMS_QUANT_OPS.
|
||||
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
|
||||
] if do_fusion else [RMS_OP]
|
||||
- ops = OPS_IN_MODEL + rms_ops
|
||||
+ silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
|
||||
+ quant_key == kFp8StaticTensorSym else [
|
||||
+ SILU_MUL_OP
|
||||
+ ]
|
||||
|
||||
+ ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
|
||||
|
||||
for op in ops:
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
|
||||
Reference in New Issue
Block a user