[Compile] Conditional compilation. Introduce compile_ranges (#24252)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
This commit is contained in:
Ilya Markov
2025-12-05 19:17:32 +01:00
committed by GitHub
parent 66e674cdd5
commit 4e26d3b09e
15 changed files with 582 additions and 268 deletions

View File

@@ -26,7 +26,7 @@ from vllm.compilation.partition_rules import (
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.utils import hash_factors
from vllm.config.utils import Range, hash_factors
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform
@@ -90,7 +90,7 @@ class CompilerManager:
"""
def __init__(self, compilation_config: CompilationConfig):
self.cache: dict[tuple[int | None, int, str], Any] = dict()
self.cache: dict[tuple[Range, int, str], Any] = dict()
self.is_cache_updated = False
self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config)
@@ -99,11 +99,11 @@ class CompilerManager:
return self.compiler.compute_hash(vllm_config)
@contextmanager
def compile_context(self, runtime_shape: int | None = None):
def compile_context(self, compile_range: Range):
"""Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context)."""
with pass_context(runtime_shape):
with pass_context(compile_range):
if self.compilation_config.use_inductor_graph_partition:
with inductor_partition_rule_context(
self.compilation_config.splitting_ops
@@ -159,29 +159,21 @@ class CompilerManager:
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable | None:
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
if (compile_range, graph_index, self.compiler.name) not in self.cache:
return None
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, runtime_shape
handle, graph, example_inputs, graph_index, compile_range
)
logger.debug(
"Directly load the %s-th graph for compile range %sfrom %s via handle %s",
graph_index,
str(compile_range),
self.compiler.name,
handle,
)
if runtime_shape is None:
logger.debug(
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Directly load the %s-th graph for shape %s from %s via handle %s",
graph_index,
str(runtime_shape),
self.compiler.name,
handle,
)
return compiled_graph
def compile(
@@ -190,9 +182,9 @@ class CompilerManager:
example_inputs,
additional_inductor_config,
compilation_config: CompilationConfig,
compile_range: Range,
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: int | None = None,
) -> Any:
if graph_index == 0:
# before compiling the first graph, record the start time
@@ -204,7 +196,7 @@ class CompilerManager:
compiled_graph = None
# try to load from the cache
compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
if compiled_graph is not None:
if graph_index == num_graphs - 1:
# after loading the last graph for this shape, record the time.
@@ -212,19 +204,12 @@ class CompilerManager:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info(
"Directly load the compiled graph(s) for dynamic shape "
"from the cache, took %.3f s",
elapsed,
)
else:
logger.info(
"Directly load the compiled graph(s) for shape %s "
"from the cache, took %.3f s",
str(runtime_shape),
elapsed,
)
logger.info(
"Directly load the compiled graph(s) for compile range %s "
"from the cache, took %.3f s",
str(compile_range),
elapsed,
)
return compiled_graph
# no compiler cached the graph, or the cache is disabled,
@@ -233,14 +218,15 @@ class CompilerManager:
# Let compile_fx generate a key for us
maybe_key = None
else:
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
with self.compile_context(runtime_shape):
maybe_key = "artifact_compile_range_"
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range):
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
additional_inductor_config,
runtime_shape,
compile_range,
maybe_key,
)
@@ -248,55 +234,34 @@ class CompilerManager:
# store the artifact in the cache
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True
if graph_index == 0:
# adds some info logging for the first graph
if runtime_shape is None:
logger.info_once(
"Cache the graph for dynamic shape for later use", scope="local"
)
else:
logger.info_once(
"Cache the graph of shape %s for later use",
str(runtime_shape),
scope="local",
)
if runtime_shape is None:
logger.debug(
"Store the %s-th graph for dynamic shape from %s via handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Store the %s-th graph for shape %s from %s via handle %s",
graph_index,
str(runtime_shape),
self.compiler.name,
handle,
logger.info_once(
"Cache the graph of compile range %s for later use",
str(compile_range),
)
logger.debug(
"Store the %s-th graph for compile range%s from %s via handle %s",
graph_index,
str(compile_range),
self.compiler.name,
handle,
)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info_once(
"Compiling a graph for dynamic shape takes %.2f s",
elapsed,
scope="local",
)
else:
logger.info_once(
"Compiling a graph for shape %s takes %.2f s",
runtime_shape,
elapsed,
scope="local",
)
logger.info_once(
"Compiling a graph for compile range %s takes %.2f s",
str(compile_range),
elapsed,
scope="local",
)
return compiled_graph
@@ -427,19 +392,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time
compiled_graph_for_dynamic_shape = (
self.vllm_backend.compiler_manager.compile(
submod,
args,
self.vllm_backend.inductor_config,
self.compilation_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
)
)
# Lazy import here to avoid circular import
from .piecewise_backend import PiecewiseBackend
@@ -449,7 +402,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
index,
len(self.compile_submod_names),
sym_shape_indices,
compiled_graph_for_dynamic_shape,
self.vllm_backend,
)
@@ -589,8 +541,13 @@ class VllmBackend:
)
else:
# Config should automatically wrap all inductor passes
assert isinstance(self.inductor_config[self.pass_key], InductorPass)
self.pass_manager.add(self.inductor_config[self.pass_key])
assert isinstance(
self.compilation_config.inductor_compile_config[self.pass_key],
InductorPass,
)
self.pass_manager.add(
self.compilation_config.inductor_compile_config[self.pass_key]
)
self.inductor_config[self.pass_key] = self.pass_manager
def __call__(

View File

@@ -10,6 +10,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
@@ -431,7 +432,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool:
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass.
# It inherits the same applicability condition as `SequenceParallelismPass`.
# See `SequenceParallelismPass.is_applicable` for more details.
@@ -441,7 +442,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
):
return True
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
return compile_range.is_single_size() and compile_range.end % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
@@ -505,91 +506,60 @@ if flashinfer_comm is not None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
max_tensor_size = max_token_num * hidden_size * element_size
assert current_tensor_size <= max_tensor_size, (
f"Current tensor size {current_tensor_size} is larger than "
f"max token num {max_token_num} * hidden size {hidden_size} * "
f"element size {element_size}"
)
device_capability = current_platform.get_device_capability().to_int()
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, {}
).get(world_size, None)
# Use one shot if no max size is specified
use_oneshot = (
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
)
if num_tokens <= max_token_num:
device_capability = current_platform.get_device_capability().to_int()
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, {}
).get(world_size, None)
# Use one shot if no max size for one shot is specified
use_oneshot = (
max_one_shot_size_mb is None
or current_tensor_size <= max_one_shot_size_mb * MiB
)
assert _FI_WORKSPACE_TENSOR is not None, (
"Flashinfer must be enabled when using flashinfer"
)
if norm_out is None:
norm_out = allreduce_in
residual_out = residual
else:
# return residual_out as allreduce_out with zeroed residual_in
# as flashinfer does not support rms_norm
# and allreduce_out together
residual_out = allreduce_in
# For the sizes that are smaller than the max size,
# we only use flashinfer one shot allreduce
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=allreduce_in,
token_num=allreduce_in.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
world_rank=world_rank,
world_size=world_size,
hidden_dim=allreduce_in.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
launch_with_pdl=launch_with_pdl,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=pattern_code,
allreduce_out=None,
quant_out=quant_out,
scale_out=scale_out,
# in vllm we only support swizzled layout
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=scale_factor,
)
assert _FI_WORKSPACE_TENSOR is not None, (
"Flashinfer must be enabled when using flashinfer"
)
if norm_out is None:
norm_out = allreduce_in
residual_out = residual
else:
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
if scale_factor is not None and scale_out is None:
# Do fused rms norm static fp8 quant fused op
if norm_out is None:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
quant_out,
allreduce_out,
residual,
rms_gamma,
scale_factor,
rms_eps,
)
else:
torch.ops._C.rms_norm_static_fp8_quant(
quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps
)
else:
if norm_out is None:
torch.ops._C.fused_add_rms_norm(
allreduce_out, residual, rms_gamma, rms_eps
)
norm_out = allreduce_out
else:
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
if scale_factor is not None and scale_out is not None:
torch.ops._C.scaled_fp4_quant(
quant_out, norm_out, scale_out, scale_factor
)
if scale_factor is None or norm_out is not None:
# we need to return allreduce output
# in cases of non quant fused AR + RMS norm
# and fused AR + RMS norm + quant without fused add
allreduce_in.copy_(allreduce_out)
# return residual_out as allreduce_out with zeroed residual_in
# as flashinfer does not support rms_norm
# and allreduce_out together
residual_out = allreduce_in
# For the sizes that are smaller than the max size,
# we only use flashinfer one shot allreduce
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=allreduce_in,
token_num=allreduce_in.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
world_rank=world_rank,
world_size=world_size,
hidden_dim=allreduce_in.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
launch_with_pdl=launch_with_pdl,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=pattern_code,
allreduce_out=None,
quant_out=quant_out,
scale_out=scale_out,
# in vllm we only support swizzled layout
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=scale_factor,
)
def call_trtllm_fused_allreduce_norm_fake(
allreduce_in: torch.Tensor,
@@ -1128,7 +1098,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
if max_size is None:
# Flashinfer doesn't support current world size
logger.warning(
"Flashinfer allreduce fusion is not supported for world size %s",
"Flashinfer allreduce fusion is not supported for world size %s"
" or max size is not provided",
self.tp_size,
)
return
@@ -1216,6 +1187,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.disabled = False
def is_applicable_for_range(self, compile_range: Range) -> bool:
    """Tell whether this pass may run for ``compile_range``.

    The pass applies only when the upper end of the range does not exceed
    ``self.max_token_num``.
    """
    range_upper_bound = compile_range.end
    return range_upper_bound <= self.max_token_num
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
if self.disabled:

View File

@@ -15,6 +15,7 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -63,16 +64,16 @@ class CompilerInterface:
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
"""
Compile the graph with the given example inputs and compiler config,
with a runtime shape. If the `runtime_shape` is None, it means
the `example_inputs` have a dynamic shape. Otherwise, the
`runtime_shape` specifies the shape of the inputs. Right now we only
support one variable shape for all inputs, which is the batchsize
(number of tokens) during inference.
with a range. The `compile_range` specifies the range of the inputs,
it could be concrete size (if compile_sizes is provided), e.g. [4, 4]
or a range [5, 8].
Right now we only support one variable in ranges for all inputs,
which is the batchsize (number of tokens) during inference.
Dynamo will make sure `graph(*example_inputs)` is valid.
@@ -98,7 +99,7 @@ class CompilerInterface:
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
"""
Load the compiled function from the handle.
@@ -212,20 +213,20 @@ class InductorStandaloneAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, runtime_shape)
set_inductor_config(current_config, compile_range)
set_functorch_config()
if isinstance(runtime_shape, int):
if compile_range.is_single_size():
dynamic_shapes = "from_example_inputs"
else:
dynamic_shapes = "from_tracing_context"
dynamic_shapes = "from_graph"
from torch._inductor import standalone_compile
@@ -235,7 +236,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
dynamic_shapes=dynamic_shapes,
options={"config_patches": current_config},
)
# Save the compiled artifact to disk in the specified path
assert key is not None
path = os.path.join(self.cache_dir, key)
@@ -251,7 +251,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
@@ -315,7 +315,7 @@ class InductorAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
@@ -329,7 +329,7 @@ class InductorAdaptor(CompilerInterface):
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, runtime_shape)
set_inductor_config(current_config, compile_range)
set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
@@ -512,7 +512,7 @@ class InductorAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
@@ -608,9 +608,9 @@ class InductorAdaptor(CompilerInterface):
return contextlib.nullcontext()
def set_inductor_config(config, runtime_shape):
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
def set_inductor_config(config, compile_range: Range):
if compile_range.is_single_size():
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
config["coordinate_descent_tuning"] = (
@@ -630,7 +630,7 @@ class EagerAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_eager_compiles += 1

View File

@@ -14,6 +14,7 @@ import torch
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
from vllm.config.utils import Range
from vllm.utils.torch_utils import is_torch_equal_or_newer
if is_torch_equal_or_newer("2.6"):
@@ -28,8 +29,8 @@ _pass_context = None
class PassContext:
def __init__(self, runtime_shape: int | None):
self.runtime_shape = runtime_shape
def __init__(self, compile_range: Range):
self.compile_range: Range = compile_range
def get_pass_context() -> PassContext:
@@ -39,13 +40,13 @@ def get_pass_context() -> PassContext:
@contextmanager
def pass_context(runtime_shape: int | None):
def pass_context(compile_range: Range):
"""A context manager that stores the current pass context,
which holds the compile range being compiled.
"""
global _pass_context
prev_context = _pass_context
_pass_context = PassContext(runtime_shape)
_pass_context = PassContext(compile_range)
try:
yield
finally:
@@ -96,7 +97,7 @@ class InductorPass(CustomGraphPass):
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable(self, shape: int | None):
def is_applicable_for_range(self, compile_range: Range) -> bool:
    """Report whether this pass should run for ``compile_range``.

    This implementation accepts every range and always returns ``True``.
    """
    return True

View File

@@ -24,7 +24,11 @@ if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .inductor_pass import (
CustomGraphPass,
InductorPass,
get_pass_context,
)
from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
@@ -70,13 +74,13 @@ class PostGradPassManager(CustomGraphPass):
def __call__(self, graph: fx.Graph):
VllmInductorPass.dump_prefix = 0 # reset dump index
shape = get_pass_context().runtime_shape
compile_range = get_pass_context().compile_range
for pass_ in self.passes:
if pass_.is_applicable(shape):
if pass_.is_applicable_for_range(compile_range):
pass_(graph)
VllmInductorPass.dump_prefix += 1
else:
logger.debug("Skipping %s with shape %s", pass_, shape)
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
@@ -133,4 +137,8 @@ class PostGradPassManager(CustomGraphPass):
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
state["compile_range"] = str(get_pass_context().compile_range)
return InductorPass.hash_dict(state)

View File

@@ -7,18 +7,18 @@ from typing import Any
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
class RangeEntry:
compile_range: Range
compiled: bool = False
runnable: Callable = None # type: ignore
@@ -31,7 +31,6 @@ class PiecewiseBackend:
piecewise_compile_index: int,
total_piecewise_compiles: int,
sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend,
):
"""
@@ -55,67 +54,111 @@ class PiecewiseBackend:
self.is_full_graph = total_piecewise_compiles == 1
self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)
self.compile_ranges = self.compilation_config.get_compile_ranges()
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
logger.debug_once(log_string)
self.first_run_finished = False
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
self.compile_sizes = self.compilation_config.compile_sizes
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
logger.debug_once(log_string)
self.sym_shape_indices = sym_shape_indices
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# the entries for the ranges that we need to compile
self.range_entries: dict[Range, RangeEntry] = {}
# the entries for different shapes that we need to compile
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile,
# to_be_compiled_ranges tracks the remaining ranges to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly.
for shape in self.compile_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
runnable=self.compiled_graph_for_general_shape,
for size in self.compile_sizes:
range = Range(start=size, end=size)
if range not in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
self.to_be_compiled_ranges.add(range)
for range in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)
def _fakify_args(self, args: list[Any]) -> list[Any]:
    """Return the fake example inputs recorded on the graph's placeholders.

    Fake example inputs are passed instead of the real ``args`` because
    torch.compile would otherwise fakify the real inputs itself, which can
    duck-shape a non-dynamic dimension to an existing shape whose hint
    matches its value. If that wrongly dynamic dimension is later
    specialized, the whole shape is forced to specialize with it.
    See issue https://github.com/vllm-project/vllm/issues/27899

    The number of leading placeholders must equal ``len(args)``.
    """
    placeholder_values = []
    for node in self.graph.graph.nodes:
        # All placeholders come first, so stop at the first other node.
        if node.op != "placeholder":
            break
        placeholder_values.append(node.meta["example_value"])
    assert len(placeholder_values) == len(args)
    return placeholder_values
runtime_shape = args[self.sym_shape_indices[0]]
def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any:
if not range_entry.compiled:
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args)
entry = self.concrete_size_entries[runtime_shape]
if not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
entry.runnable = self.vllm_backend.compiler_manager.compile(
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args = (
self._fakify_args(args)
if not range_entry.compile_range.is_single_size()
else args
)
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
)
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()
self.check_for_ending_compilation()
return entry.runnable(*args)
def _find_range_for_shape(self, runtime_shape: int) -> "RangeEntry | None":
    """Look up the range entry that should serve ``runtime_shape``.

    An exact compile size takes precedence: if ``runtime_shape`` is in
    ``self.compile_sizes``, the entry stored for the single-size range
    ``[runtime_shape, runtime_shape]`` is returned. Otherwise the entry of
    the first range in ``self.compile_ranges`` that contains the shape is
    returned, or ``None`` when no range contains it.

    Note that the result is the entry held in ``self.range_entries``
    (the object carrying the compiled runnable), not the range key itself.
    """
    if runtime_shape in self.compile_sizes:
        return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
    # Named `candidate` rather than `range` so the builtin is not shadowed.
    for candidate in self.compile_ranges:
        if runtime_shape in candidate:
            return self.range_entries[candidate]
    return None
def __call__(self, *args) -> Any:
    """Run the graph registered for the range covering the runtime shape.

    The runtime shape is read from ``args`` at the first symbolic-shape
    index. The matching range entry is handed to
    ``_maybe_compile_for_range_entry`` and then its runnable is invoked
    with the original ``args``. Fails an assertion when no range entry
    matches the shape.
    """
    shape_arg_index = self.sym_shape_indices[0]
    runtime_shape = args[shape_arg_index]
    entry = self._find_range_for_shape(runtime_shape)
    assert entry is not None, (
        f"Shape out of considered range: {runtime_shape} "
        "[1, max_num_batched_tokens]"
    )
    self._maybe_compile_for_range_entry(entry, args)
    compiled_fn = entry.runnable
    return compiled_fn(*args)

View File

@@ -9,6 +9,7 @@ import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool:
def is_applicable_for_range(self, compile_range: Range) -> bool:
# When sequence parallelism is enabled, the residual tensor from RMSNorm
# needs to be split along the sequence dimension. However, this dimension
# is symbolic during piecewise compilation, and splitting symbolic shapes
@@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
):
return True
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):