[BugFix] Fix de-functionalization pass for rotary_embedding (#23953)

Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
Angela Yi
2025-10-03 15:44:18 -07:00
committed by GitHub
parent b71fcd4905
commit 7cfa4b24bf
3 changed files with 263 additions and 84 deletions

View File

@@ -46,23 +46,43 @@ class FixFunctionalizationPass(VllmInductorPass):
if at_target == torch.ops._C.rotary_embedding.default:
query = kwargs['query']
mm_node = query.args[0].args[0]
key = kwargs['key']
getitem_nodes = self.getitem_users(node)
# rotary_embedding is a special case: the two mutating inputs
# are query and key, which are slices of mm_node.
# While functionalized, results at[1] and at[2] are scattered
# back into mm_node. After de-functionalization, we can just
# use mm_node directly.
for idx, user in self.getitem_users(node).items():
for user_of_getitem in user.users:
if is_func(user_of_getitem,
torch.ops.aten.slice_scatter.default):
user_of_getitem.replace_all_uses_with(mm_node)
self._remove(user_of_getitem)
self._remove(user)
if (is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0],
torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users)):
# Pattern where query and key are slices of an mm_node.
# While functionalized, results at [1] and [2] are scattered
# back into mm_node. So after de-functionalization, we can
# just use mm_node directly.
self.insert_defunctionalized(graph, node)
self._remove(node)
mm_node = query.args[0].args[0]
for user in getitem_nodes.values():
for user_of_getitem in user.users:
if is_func(user_of_getitem,
torch.ops.aten.slice_scatter.default):
user_of_getitem.replace_all_uses_with(mm_node)
self._remove(user_of_getitem)
self._remove(user)
self.insert_defunctionalized(graph, node)
self._remove(node)
else:
# Directly replace the auto_functionalize(rotary_embedding)
# with the inplace rotary_embedding. In theory, we shouldn't
# do this blindly, but in practice in vLLM it's ok. The best
# solution is to use auto_functionalization_v2 and then use
# inductor's builtin defunctionalization (reinplacing) pass.
mutated_args = {1: 'query', 2: 'key'}
self.defunctionalize(graph, node, mutated_args)
# rms_norm replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default: