[ROCm]: fix aiter rope functionalization (#35533)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
|
||||
rope_targets = [torch.ops._C.rotary_embedding.default]
|
||||
|
||||
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
|
||||
rope_targets.append(
|
||||
torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
|
||||
)
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
if at_target in rope_targets:
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = self.getitem_users(node)
|
||||
|
||||
Reference in New Issue
Block a user