Add option to use torch._inductor.standalone_compile (#17057)
Signed-off-by: rzou <zou3519@gmail.com>
This commit is contained in:
@@ -50,7 +50,8 @@ class CompilerInterface:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
compiler_config: Dict[str, Any],
|
||||
runtime_shape: Optional[int] = None
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> Tuple[Optional[Callable], Optional[Any]]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
@@ -71,6 +72,10 @@ class CompilerInterface:
|
||||
If the compiler doesn't support caching, it should return None for the
|
||||
handle. If the compiler fails to compile the graph, it should return
|
||||
None for the compiled function as well.
|
||||
|
||||
`key` is required for InductorStandaloneAdaptor, it specifies where to
|
||||
save the compiled artifact. The compiled artifact gets saved to
|
||||
`cache_dir/key`.
|
||||
"""
|
||||
return None, None
|
||||
|
||||
@@ -127,23 +132,108 @@ class AlwaysHitShapeEnv:
|
||||
return ""
|
||||
|
||||
|
||||
def get_inductor_factors() -> List[Any]:
|
||||
factors: List[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler.
    Requires PyTorch 2.8+.
    This is not on by default yet, but we plan to turn it on by default for
    PyTorch 2.8.

    Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
    """
    name = "inductor_standalone"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        # Hash is derived only from system + PyTorch state (see
        # get_inductor_factors); truncated to 10 hex chars for compact
        # cache-directory names — not security-sensitive, hence md5.
        factors = get_inductor_factors()
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()[:10]
        return hash_str

    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
        # Only records where artifacts go; `disable_cache` is accepted for
        # interface compatibility but is not consulted by this adaptor.
        self.cache_dir = cache_dir

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        compiler_config: Dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> Tuple[Optional[Callable], Optional[Any]]:
        """Compile `graph` via torch._inductor.standalone_compile.

        The compiled artifact is saved under `self.cache_dir/key` in the
        "unpacked" on-disk format, and the returned handle `(key, path)`
        is what `load` later uses to restore it.  `key` is therefore
        required (asserted below).
        """
        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)
        set_inductor_config(current_config, runtime_shape)

        # An int runtime_shape means we are compiling for one concrete
        # batch size, so shapes can be specialized from example inputs;
        # otherwise defer to the tracing context's dynamic shapes.
        if isinstance(runtime_shape, int):
            dynamic_shapes = "from_example_inputs"
        else:
            dynamic_shapes = "from_tracing_context"

        from torch._inductor import standalone_compile
        with pass_context(runtime_shape):
            compiled_graph = standalone_compile(
                graph,
                example_inputs,
                dynamic_shapes=dynamic_shapes,
                options={"config_patches": current_config})

        # Save the compiled artifact to disk in the specified path
        assert key is not None
        path = os.path.join(self.cache_dir, key)
        compiled_graph.save(path=path, format="unpacked")
        return compiled_graph, (key, path)

    def load(self,
             handle: Any,
             graph: fx.GraphModule,
             example_inputs: List[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Callable:
        """Restore an artifact saved by `compile` from its (key, path) handle."""
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        path = handle[1]
        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
            path=path, format="unpacked")
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        def compiled_graph_wrapper(*args):
            graph_output = inductor_compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
|
||||
"""
|
||||
name = "inductor"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
    """Return a short cache key derived from system + PyTorch state.

    Fix: the previous body collected the system and torch factors by hand
    and then immediately overwrote `factors` with get_inductor_factors()
    (which gathers exactly the same two factors), leaving the manual
    collection as dead code.  Only the shared helper call is kept so this
    stays consistent with InductorStandaloneAdaptor.compute_hash.
    """
    factors = get_inductor_factors()
    # Truncated md5 of the factor list — a compact, non-cryptographic
    # cache-directory name, hence usedforsecurity=False.
    hash_str = hashlib.md5(str(factors).encode(),
                           usedforsecurity=False).hexdigest()[:10]
    return hash_str
||||
@@ -168,23 +258,19 @@ class InductorAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
compiler_config: Dict[str, Any],
|
||||
runtime_shape: Optional[int] = None
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> Tuple[Optional[Callable], Optional[Any]]:
|
||||
current_config = {}
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
|
||||
# disable remote cache
|
||||
current_config["fx_graph_cache"] = True
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
current_config["max_autotune"] = True
|
||||
current_config["coordinate_descent_tuning"] = True
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
@@ -422,6 +508,14 @@ class InductorAdaptor(CompilerInterface):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config, runtime_shape):
    """Enable shape-specific Inductor autotuning options in *config*, in place.

    Mutates *config* only when *runtime_shape* is a concrete int; a
    symbolic/None shape leaves the dict untouched.
    """
    if not isinstance(runtime_shape, int):
        return
    # With a fixed batch size, spending time tuning triton kernel
    # parameters pays off.
    config["max_autotune"] = True
    config["coordinate_descent_tuning"] = True
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
name = "eager"
|
||||
|
||||
@@ -430,7 +524,8 @@ class EagerAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: List[Any],
|
||||
compiler_config: Dict[str, Any],
|
||||
runtime_shape: Optional[int] = None
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> Tuple[Optional[Callable], Optional[Any]]:
|
||||
# we don't need to compile the graph, just return the graph itself.
|
||||
# It does not support caching, return None for the handle.
|
||||
|
||||
Reference in New Issue
Block a user