[Bugfix] Defunctionalize TRTLLM AR+Norm op for avoiding extra clone kernel before it (#29631)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
elvischenv
2025-12-03 13:15:50 +08:00
committed by GitHub
parent b08025a83b
commit c719c40540
2 changed files with 14 additions and 2 deletions

View File

@@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
]:
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
elif (
at_target
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
):
mutated_args = {
1: "allreduce_in",
2: "residual",
3: "norm_out",
4: "quant_out",
5: "scale_out",
}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.

View File

@@ -75,8 +75,8 @@ def find_op_nodes(
return
assert isinstance(op, OpOverload)
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)
yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
if n.args[0] == op: