[torch.compile] Dynamic fp8 + rms_norm fusion (#10906)

Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Luka Govedič
2024-12-12 22:19:23 -05:00
committed by GitHub
parent 78ed8f57d8
commit 30870b4f66
20 changed files with 1735 additions and 251 deletions

View File

@@ -4,10 +4,10 @@ import torch
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.compilation.vllm_inductor_pass import is_func
from vllm.config import CompilationConfig
from .backend import TestBackend
@@ -35,12 +35,16 @@ prompts = [
]
@pytest.mark.parametrize("model",
["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"])
@pytest.mark.parametrize(
"model, quant_key",
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
kFp8DynamicTokenSym)])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fix_functionalization(model: str, do_fusion: bool):
def test_fix_functionalization(model: str, quant_key: QuantKey,
do_fusion: bool):
torch.set_default_device("cuda")
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
@@ -78,8 +82,9 @@ def test_fix_functionalization(model: str, do_fusion: bool):
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"]
if do_fusion else [RMS_OP])
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
] if do_fusion else [RMS_OP]
ops = OPS_IN_MODEL + rms_ops
for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)