[Compile] Conditional compilation. Introduce compile_ranges (#24252)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
This commit is contained in:
Ilya Markov
2025-12-05 19:17:32 +01:00
committed by GitHub
parent 66e674cdd5
commit 4e26d3b09e
15 changed files with 582 additions and 268 deletions

View File

@@ -15,6 +15,7 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -63,16 +64,16 @@ class CompilerInterface:
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
"""
Compile the graph with the given example inputs and compiler config,
with a runtime shape. If the `runtime_shape` is None, it means
the `example_inputs` have a dynamic shape. Otherwise, the
`runtime_shape` specifies the shape of the inputs. Right now we only
support one variable shape for all inputs, which is the batchsize
(number of tokens) during inference.
with a range. The `compile_range` specifies the range of the inputs,
it could be concrete size (if compile_sizes is provided), e.g. [4, 4]
or a range [5, 8].
Right now we only support one variable in ranges for all inputs,
which is the batchsize (number of tokens) during inference.
Dynamo will make sure `graph(*example_inputs)` is valid.
@@ -98,7 +99,7 @@ class CompilerInterface:
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
"""
Load the compiled function from the handle.
@@ -212,20 +213,20 @@ class InductorStandaloneAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, runtime_shape)
set_inductor_config(current_config, compile_range)
set_functorch_config()
if isinstance(runtime_shape, int):
if compile_range.is_single_size():
dynamic_shapes = "from_example_inputs"
else:
dynamic_shapes = "from_tracing_context"
dynamic_shapes = "from_graph"
from torch._inductor import standalone_compile
@@ -235,7 +236,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
dynamic_shapes=dynamic_shapes,
options={"config_patches": current_config},
)
# Save the compiled artifact to disk in the specified path
assert key is not None
path = os.path.join(self.cache_dir, key)
@@ -251,7 +251,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
@@ -315,7 +315,7 @@ class InductorAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
@@ -329,7 +329,7 @@ class InductorAdaptor(CompilerInterface):
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, runtime_shape)
set_inductor_config(current_config, compile_range)
set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
@@ -512,7 +512,7 @@ class InductorAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
@@ -608,9 +608,9 @@ class InductorAdaptor(CompilerInterface):
return contextlib.nullcontext()
def set_inductor_config(config, runtime_shape):
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
def set_inductor_config(config, compile_range: Range):
if compile_range.is_single_size():
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
config["coordinate_descent_tuning"] = (
@@ -630,7 +630,7 @@ class EagerAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_eager_compiles += 1