[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Union
|
||||
|
||||
from torch import fx
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.compilation.fx_utils import (find_specified_fn,
|
||||
find_specified_fn_maybe)
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
@@ -48,18 +49,19 @@ class TestBackend:
|
||||
# assign by reference, will reflect the final state of the graph
|
||||
self.final_graph = graph
|
||||
|
||||
def check_before_ops(self, ops,
|
||||
find_fn=find_specified_fn, \
|
||||
find_fn_maybe=find_specified_fn_maybe, \
|
||||
ops_fully_replaced=True):
|
||||
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
|
||||
for op in ops:
|
||||
find_fn(self.graph_pre_pass.nodes, op)
|
||||
if ops_fully_replaced:
|
||||
assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
|
||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
||||
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
||||
if fully_replaced:
|
||||
assert num_post == 0, \
|
||||
f"Unexpected op {op.name()} in post-pass graph"
|
||||
|
||||
def check_after_ops(self, ops,
|
||||
find_fn=find_specified_fn, \
|
||||
find_fn_maybe=find_specified_fn_maybe):
|
||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
||||
for op in ops:
|
||||
find_fn(self.graph_post_pass.nodes, op)
|
||||
assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
|
||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
|
||||
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
|
||||
Reference in New Issue
Block a user