[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Union
|
||||
|
||||
from torch import fx
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.compilation.fx_utils import (find_specified_fn,
|
||||
find_specified_fn_maybe)
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
@@ -48,18 +49,19 @@ class TestBackend:
|
||||
# assign by reference, will reflect the final state of the graph
|
||||
self.final_graph = graph
|
||||
|
||||
def check_before_ops(self, ops,
|
||||
find_fn=find_specified_fn, \
|
||||
find_fn_maybe=find_specified_fn_maybe, \
|
||||
ops_fully_replaced=True):
|
||||
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
|
||||
for op in ops:
|
||||
find_fn(self.graph_pre_pass.nodes, op)
|
||||
if ops_fully_replaced:
|
||||
assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
|
||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
||||
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
||||
if fully_replaced:
|
||||
assert num_post == 0, \
|
||||
f"Unexpected op {op.name()} in post-pass graph"
|
||||
|
||||
def check_after_ops(self, ops,
|
||||
find_fn=find_specified_fn, \
|
||||
find_fn_maybe=find_specified_fn_maybe):
|
||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
||||
for op in ops:
|
||||
find_fn(self.graph_post_pass.nodes, op)
|
||||
assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
|
||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
|
||||
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
|
||||
@@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
|
||||
# In pre-nodes, all gather or reduce scatter should exist,
|
||||
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
|
||||
backend.check_before_ops(model.ops_in_model_before(),
|
||||
ops_fully_replaced=False)
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
|
||||
# In post-nodes, fused_matmul_reduce_scatter or
|
||||
# fused_all_gather_matmul should exist
|
||||
|
||||
@@ -7,8 +7,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
import vllm.plugins
|
||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
||||
FusionPass, QuantKey)
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
||||
FusionPass, GroupShape, QuantKey)
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||
VllmConfig)
|
||||
@@ -30,9 +29,10 @@ class TestModel(torch.nn.Module):
|
||||
self.cutlass_fp8_enabled = cutlass_fp8_enabled
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.key = QuantKey(dtype=FP8_DTYPE,
|
||||
static=static,
|
||||
per_tensor=static,
|
||||
group_shape=group_shape,
|
||||
symmetric=True)
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
@@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
|
||||
|
||||
# In pre-nodes, fp8 quant should be there and fused kernels should not
|
||||
backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
|
||||
find_auto_fn_maybe)
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, fused kernels should be there and fp8 quant should not
|
||||
backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
|
||||
find_auto_fn_maybe)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
131
tests/compile/test_fusion_attn.py
Normal file
131
tests/compile/test_fusion_attn.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch._dynamo
|
||||
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.models.utils import check_outputs_equal
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# globals needed for string-import custom Dynamo backend field
|
||||
backend: Optional[TestBackend] = None
|
||||
backend_unfused: Optional[TestBackend] = None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model, quant_key",
    [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
    "use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
                          quant_key: QuantKey, use_triton_fa: bool):
    """Check that AttnFusionPass fuses output quantization into attention.

    The model is run twice with ``CompilationLevel.DYNAMO_AS_IS``: once with a
    backend that only applies ``NoOpEliminationPass`` (unfused reference) and
    once with a backend that additionally applies ``AttnFusionPass``. The test
    then verifies that

    * when at least one attention layer reports support for fused output
      quantization, the quant op count drops between the pre- and post-pass
      graphs and each attention node gains an ``output_scale`` exactly when
      its layer reports support;
    * the generated outputs of the fused and unfused runs are equal.

    Args:
        example_prompts: prompt fixture; prompt index 4 is excluded.
        monkeypatch: pytest fixture used to set the vLLM environment flags.
        model: name of the model to load.
        quant_key: quantization scheme whose op is expected to be fused.
        use_triton_fa: value written to ``VLLM_USE_TRITON_FLASH_ATTN``.
    """
    # Clean Dynamo cache to avoid reusing other test cases
    # (for some reason the reset at the end is not enough)
    torch._dynamo.reset()

    # The backends are looked up by dotted import path (see the ``backend=``
    # strings below), so they have to live in module globals.
    global backend, backend_unfused

    try:
        use_v1 = False  # can be made a param once V1 support added
        monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN",
                           str(int(use_triton_fa)))

        # Prompt 4 seems too open-ended, differs between fused and unfused
        # (both outputs look reasonable though)
        prompts = example_prompts[:4] + example_prompts[5:]

        # ---- Reference run: no attention fusion ----
        compile_config = CompilationConfig(
            # DYNAMO_AS_IS triggers custom backend & does full Dynamo
            # compilation. DYNAMO_ONCE does not properly propagate shapes.
            level=CompilationLevel.DYNAMO_AS_IS,
            backend="tests.compile.test_fusion_attn.backend_unfused",
        )
        vllm_config = VllmConfig(compilation_config=compile_config)
        backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

        llm = LLM(model,
                  enforce_eager=True,
                  compilation_config=compile_config,
                  gpu_memory_utilization=0.9,
                  max_model_len=2048)

        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=10,
                                         top_p=0.95)

        unfused_output = llm.generate(prompts, sampling_params)
        backend_unfused = None  # Reset backend to make sure llm gets released
        del llm

        # ---- Fused run ----
        compile_config = CompilationConfig(
            # DYNAMO_AS_IS triggers custom backend & does full Dynamo
            # compilation. DYNAMO_ONCE does not properly propagate shapes.
            level=CompilationLevel.DYNAMO_AS_IS,
            backend="tests.compile.test_fusion_attn.backend",
        )
        vllm_config = VllmConfig(compilation_config=compile_config)

        # AttnFusionPass needs attention layers to be registered in config
        # upon init so we initialize it during compilation.
        def attn_pass(*args, **kw):
            return AttnFusionPass(vllm_config)(*args, **kw)

        backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
        llm2 = LLM(model,
                   enforce_eager=True,
                   compilation_config=compile_config,
                   gpu_memory_utilization=0.9,
                   max_model_len=2048)

        # check support: one flag per registered attention layer
        attn_fusion_supported = [
            layer.impl.fused_output_quant_supported(quant_key.dtype,
                                                    quant_key.static,
                                                    quant_key.group_shape)
            for layer in compile_config.static_forward_context.values()
        ]

        print(f"{attn_fusion_supported=}")
        if any(attn_fusion_supported):
            # Check quant ops; some may remain if not every layer was fused
            backend.check_before_ops([QUANT_OPS[quant_key]],
                                     fully_replaced=False)

        # attention ops present in both, just output_scale param changes
        attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
        attn_nodes_post = list(find_op_nodes(ATTN_OP,
                                             backend.graph_post_pass))
        assert len(attn_nodes_pre) == len(attn_nodes_post)

        for i, (node_pre, node_post) in enumerate(
                zip(attn_nodes_pre, attn_nodes_post)):
            assert node_pre.kwargs["output_scale"] is None
            fused = node_post.kwargs["output_scale"] is not None
            supported = attn_fusion_supported[i]
            # Report what actually happened vs. what the layer claims to
            # support, so a failure reads unambiguously in either direction.
            assert fused == supported, \
                f"Node {i}: output quant was " \
                f"{'fused' if fused else 'not fused'} but the attention " \
                f"layer reports fusion as " \
                f"{'supported' if supported else 'unsupported'}"

        # check outputs
        fused_output = llm2.generate(prompts, sampling_params)

        # transform outputs to format expected by check_outputs_equal
        def sample_outs(sample):
            return (list(sample.token_ids), sample.text)

        def outs_lst(request_outputs):
            return [sample_outs(ro.outputs[0]) for ro in request_outputs]

        check_outputs_equal(
            outputs_0_lst=outs_lst(unfused_output),
            outputs_1_lst=outs_lst(fused_output),
            name_0="unfused",
            name_1="fused",
        )
    finally:
        # Run the cleanup even when an assertion above fails, so a failing
        # case does not leak state into later tests.

        # Clean Dynamo cache to avoid polluting other case(s)
        torch._dynamo.reset()

        # Reset backends to make sure the LLM instances get released
        backend = None
        backend_unfused = None
|
||||
Reference in New Issue
Block a user