[ROCm]: fix aiter rope functionalization (#35533)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-02-27 16:42:30 -06:00
committed by GitHub
parent 9fa6c68fa6
commit e3691988d0

View File

@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
rope_targets = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
rope_targets.append(
torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
)
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
if at_target in rope_targets:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)