[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
131
tests/compile/test_fusion_attn.py
Normal file
131
tests/compile/test_fusion_attn.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch._dynamo
|
||||
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.models.utils import check_outputs_equal
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# globals needed for string-import custom Dynamo backend field
|
||||
backend: Optional[TestBackend] = None
|
||||
backend_unfused: Optional[TestBackend] = None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, quant_key",
|
||||
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
|
||||
@pytest.mark.parametrize(
|
||||
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test CUDA and ROCm")
|
||||
def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
quant_key: QuantKey, use_triton_fa: bool):
|
||||
# Clean Dynamo cache to avoid reusing other test cases
|
||||
# (for some reason the reset at the end is not enough)
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Use global backends
|
||||
global backend, backend_unfused
|
||||
|
||||
use_v1 = False # can be made a param once V1 support added
|
||||
monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
|
||||
|
||||
# Prompt 4 seems too open-ended, differs between fused and unfused
|
||||
# (both outputs look reasonable though)
|
||||
prompts = example_prompts[:4] + example_prompts[5:]
|
||||
|
||||
compile_config = CompilationConfig(
|
||||
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
|
||||
# DYNAMO_ONCE does not properly propagate shapes.
|
||||
level=CompilationLevel.DYNAMO_AS_IS,
|
||||
backend="tests.compile.test_fusion_attn.backend_unfused",
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config)
|
||||
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
||||
|
||||
llm = LLM(model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=2048)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=10,
|
||||
top_p=0.95)
|
||||
|
||||
unfused_output = llm.generate(prompts, sampling_params)
|
||||
backend_unfused = None # Reset backend to make sure llm gets released
|
||||
del llm
|
||||
|
||||
compile_config = CompilationConfig(
|
||||
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
|
||||
# DYNAMO_ONCE does not properly propagate shapes.
|
||||
level=CompilationLevel.DYNAMO_AS_IS,
|
||||
backend="tests.compile.test_fusion_attn.backend",
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config)
|
||||
|
||||
# AttnFusionPass needs attention layers to be registered in config upon init
|
||||
# so we initialize it during compilation.
|
||||
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
|
||||
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
|
||||
llm2 = LLM(model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=2048)
|
||||
|
||||
# check support
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key.dtype,
|
||||
quant_key.static,
|
||||
quant_key.group_shape)
|
||||
for key, layer in compile_config.static_forward_context.items()
|
||||
]
|
||||
|
||||
print(f"{attn_fusion_supported=}")
|
||||
if any(attn_fusion_supported):
|
||||
# Check quant ops
|
||||
backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
|
||||
|
||||
# attention ops present in both, just output_scale param changes
|
||||
attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
|
||||
attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
|
||||
assert len(attn_nodes_pre) == len(attn_nodes_post)
|
||||
|
||||
for i in range(len(attn_nodes_pre)):
|
||||
assert attn_nodes_pre[i].kwargs["output_scale"] is None
|
||||
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
|
||||
assert fused == attn_fusion_supported[i], \
|
||||
f"Node {i} {'' if fused else 'not '} expected " \
|
||||
f"to have fused output quant"
|
||||
|
||||
# check outputs
|
||||
fused_output = llm2.generate(prompts, sampling_params)
|
||||
|
||||
# transform outputs to format expected by check_outputs_equal
|
||||
sample_outs = lambda s: (list(s.token_ids), s.text)
|
||||
outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=outs_lst(unfused_output),
|
||||
outputs_1_lst=outs_lst(fused_output),
|
||||
name_0="unfused",
|
||||
name_1="fused",
|
||||
)
|
||||
|
||||
# Clean Dynamo cache to avoid polluting other case(s)
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Reset backend to make sure llm2 gets released
|
||||
backend = None
|
||||
Reference in New Issue
Block a user