[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-06-12 11:31:04 -04:00
committed by GitHub
parent 96846bb360
commit f98548b9da
33 changed files with 622 additions and 79 deletions

View File

@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union
from torch import fx
from torch._ops import OpOverload
from vllm.compilation.fx_utils import (find_specified_fn,
find_specified_fn_maybe)
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config
@@ -48,18 +49,19 @@ class TestBackend:
# assign by reference, will reflect the final state of the graph
self.final_graph = graph
def check_before_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe, \
ops_fully_replaced=True):
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
for op in ops:
find_fn(self.graph_pre_pass.nodes, op)
if ops_fully_replaced:
assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, \
f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe):
def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops:
find_fn(self.graph_post_pass.nodes, op)
assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"