[BugFix] Patch inductor memory plan logic (#26878)
Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,9 +3,9 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_torch_equal
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -23,6 +23,72 @@ os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
|
||||
# see https://github.com/vllm-project/vllm/issues/10619
|
||||
torch._inductor.config.compile_threads = 1
|
||||
|
||||
# ===================================================
|
||||
# torch 2.9 Inductor PythonWrapperCodegen monkeypatch
|
||||
# ===================================================
|
||||
# This change monkeypatches memory_plan_reuse in pytorch 2.9.0 to work around
|
||||
# a test failure for test_multi_graph_piecewise_compile_outputs_equal.
|
||||
# For more context, see https://github.com/pytorch/pytorch/pull/165514.
|
||||
|
||||
|
||||
def memory_plan_reuse_patched(self):
|
||||
import torch._inductor.ir as ir
|
||||
from torch._inductor.codegen.wrapper import (
|
||||
EnterSubgraphLine,
|
||||
ExitSubgraphLine,
|
||||
MemoryPlanningLine,
|
||||
MemoryPlanningState,
|
||||
SubgraphPythonWrapperCodegen,
|
||||
)
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
def get_output_names(graph_outputs) -> list[str]:
|
||||
import itertools
|
||||
|
||||
names = []
|
||||
shape_counter = itertools.count(0)
|
||||
none_counter = itertools.count(0)
|
||||
for node in graph_outputs:
|
||||
if isinstance(node, ir.NoneAsConstantBuffer):
|
||||
names.append(f"{V.graph.name}_none{next(none_counter)}")
|
||||
elif isinstance(node, ir.ShapeAsConstantBuffer):
|
||||
names.append(f"{V.graph.name}_shape{next(shape_counter)}")
|
||||
else:
|
||||
names.append(node.get_name())
|
||||
return names
|
||||
|
||||
if (
|
||||
isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen)
|
||||
and V.graph.wrapper_code.partition_signatures is not None
|
||||
):
|
||||
out_names = get_output_names(
|
||||
V.graph.wrapper_code.partition_signatures.output_nodes
|
||||
)
|
||||
else:
|
||||
out_names = V.graph.get_output_names()
|
||||
|
||||
while (
|
||||
self.lines
|
||||
and isinstance(self.lines[-1], MemoryPlanningLine)
|
||||
and self.lines[-1].node.name not in out_names # type: ignore[attr-defined]
|
||||
):
|
||||
# these lines will be pointless
|
||||
self.lines.pop()
|
||||
|
||||
# codegen allocations in two passes
|
||||
planning_states = [MemoryPlanningState()]
|
||||
past_planning_states = []
|
||||
for i in range(len(self.lines)):
|
||||
line = self.lines[i]
|
||||
if isinstance(line, MemoryPlanningLine):
|
||||
self.lines[i] = line.plan(planning_states[-1])
|
||||
elif isinstance(line, EnterSubgraphLine):
|
||||
planning_states.append(MemoryPlanningState())
|
||||
elif isinstance(line, ExitSubgraphLine):
|
||||
past_planning_states.append(planning_states.pop())
|
||||
past_planning_states.append(planning_states.pop())
|
||||
assert len(planning_states) == 0
|
||||
|
||||
|
||||
# ========================================
|
||||
# torch 2.9 Inductor Scheduler monkeypatch
|
||||
@@ -135,7 +201,9 @@ def _update_scheduler_patched(self) -> None:
|
||||
self.scheduler = Scheduler(self.operations)
|
||||
|
||||
|
||||
if version.parse(str(torch.__version__)) == version.parse("2.9.0"):
|
||||
if is_torch_equal("2.9.0"):
|
||||
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
||||
from torch._inductor.graph import GraphLowering
|
||||
|
||||
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
||||
GraphLowering._update_scheduler = _update_scheduler_patched
|
||||
|
||||
Reference in New Issue
Block a user