[torch.compile] Dynamic fp8 + rms_norm fusion (#10906)
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
@@ -4,10 +4,10 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
|
||||
find_auto_fn_maybe)
|
||||
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.reshapes import RedundantReshapesPass
|
||||
from vllm.compilation.vllm_inductor_pass import is_func
|
||||
from vllm.config import CompilationConfig
|
||||
|
||||
from .backend import TestBackend
|
||||
@@ -35,12 +35,16 @@ prompts = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model",
|
||||
["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"])
|
||||
@pytest.mark.parametrize(
|
||||
"model, quant_key",
|
||||
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
|
||||
kFp8DynamicTokenSym)])
|
||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
||||
reason="Only test on CUDA")
|
||||
def test_fix_functionalization(model: str, do_fusion: bool):
|
||||
def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
do_fusion: bool):
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
|
||||
@@ -78,8 +82,9 @@ def test_fix_functionalization(model: str, do_fusion: bool):
|
||||
|
||||
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
|
||||
# and replaced by fused quantized ops in RMS_QUANT_OPS.
|
||||
ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"]
|
||||
if do_fusion else [RMS_OP])
|
||||
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
|
||||
] if do_fusion else [RMS_OP]
|
||||
ops = OPS_IN_MODEL + rms_ops
|
||||
|
||||
for op in ops:
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
|
||||
Reference in New Issue
Block a user