[Compile] Conditional compilation. Introduce compile_ranges (#24252)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Luka Govedič <luka.govedic@gmail.com> Signed-off-by: ProExpertProg <lgovedic@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
This commit is contained in:
@@ -26,7 +26,7 @@ from vllm.compilation.partition_rules import (
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.config.utils import Range, hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy
|
||||
from vllm.platforms import current_platform
|
||||
@@ -90,7 +90,7 @@ class CompilerManager:
|
||||
"""
|
||||
|
||||
def __init__(self, compilation_config: CompilationConfig):
|
||||
self.cache: dict[tuple[int | None, int, str], Any] = dict()
|
||||
self.cache: dict[tuple[Range, int, str], Any] = dict()
|
||||
self.is_cache_updated = False
|
||||
self.compilation_config = compilation_config
|
||||
self.compiler = make_compiler(compilation_config)
|
||||
@@ -99,11 +99,11 @@ class CompilerManager:
|
||||
return self.compiler.compute_hash(vllm_config)
|
||||
|
||||
@contextmanager
|
||||
def compile_context(self, runtime_shape: int | None = None):
|
||||
def compile_context(self, compile_range: Range):
|
||||
"""Provide compilation context for the duration of compilation to set
|
||||
any torch global properties we want to scope to a single Inductor
|
||||
compilation (e.g. partition rules, pass context)."""
|
||||
with pass_context(runtime_shape):
|
||||
with pass_context(compile_range):
|
||||
if self.compilation_config.use_inductor_graph_partition:
|
||||
with inductor_partition_rule_context(
|
||||
self.compilation_config.splitting_ops
|
||||
@@ -159,29 +159,21 @@ class CompilerManager:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
) -> Callable | None:
|
||||
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
|
||||
if (compile_range, graph_index, self.compiler.name) not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
||||
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
|
||||
compiled_graph = self.compiler.load(
|
||||
handle, graph, example_inputs, graph_index, runtime_shape
|
||||
handle, graph, example_inputs, graph_index, compile_range
|
||||
)
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for compile range %sfrom %s via handle %s",
|
||||
graph_index,
|
||||
str(compile_range),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
def compile(
|
||||
@@ -190,9 +182,9 @@ class CompilerManager:
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
compile_range: Range,
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Any:
|
||||
if graph_index == 0:
|
||||
# before compiling the first graph, record the start time
|
||||
@@ -204,7 +196,7 @@ class CompilerManager:
|
||||
compiled_graph = None
|
||||
|
||||
# try to load from the cache
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
|
||||
if compiled_graph is not None:
|
||||
if graph_index == num_graphs - 1:
|
||||
# after loading the last graph for this shape, record the time.
|
||||
@@ -212,19 +204,12 @@ class CompilerManager:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for dynamic shape "
|
||||
"from the cache, took %.3f s",
|
||||
elapsed,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for shape %s "
|
||||
"from the cache, took %.3f s",
|
||||
str(runtime_shape),
|
||||
elapsed,
|
||||
)
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for compile range %s "
|
||||
"from the cache, took %.3f s",
|
||||
str(compile_range),
|
||||
elapsed,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
# no compiler cached the graph, or the cache is disabled,
|
||||
@@ -233,14 +218,15 @@ class CompilerManager:
|
||||
# Let compile_fx generate a key for us
|
||||
maybe_key = None
|
||||
else:
|
||||
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
||||
|
||||
with self.compile_context(runtime_shape):
|
||||
maybe_key = "artifact_compile_range_"
|
||||
maybe_key += f"{compile_range.start}_{compile_range.end}"
|
||||
maybe_key += f"_subgraph_{graph_index}"
|
||||
with self.compile_context(compile_range):
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
runtime_shape,
|
||||
compile_range,
|
||||
maybe_key,
|
||||
)
|
||||
|
||||
@@ -248,55 +234,34 @@ class CompilerManager:
|
||||
|
||||
# store the artifact in the cache
|
||||
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
|
||||
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
|
||||
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
|
||||
compilation_counter.num_cache_entries_updated += 1
|
||||
self.is_cache_updated = True
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
if runtime_shape is None:
|
||||
logger.info_once(
|
||||
"Cache the graph for dynamic shape for later use", scope="local"
|
||||
)
|
||||
else:
|
||||
logger.info_once(
|
||||
"Cache the graph of shape %s for later use",
|
||||
str(runtime_shape),
|
||||
scope="local",
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for dynamic shape from %s via handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
logger.info_once(
|
||||
"Cache the graph of compile range %s for later use",
|
||||
str(compile_range),
|
||||
)
|
||||
logger.debug(
|
||||
"Store the %s-th graph for compile range%s from %s via handle %s",
|
||||
graph_index,
|
||||
str(compile_range),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
|
||||
# after compiling the last graph, record the end time
|
||||
if graph_index == num_graphs - 1:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info_once(
|
||||
"Compiling a graph for dynamic shape takes %.2f s",
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
else:
|
||||
logger.info_once(
|
||||
"Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape,
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
logger.info_once(
|
||||
"Compiling a graph for compile range %s takes %.2f s",
|
||||
str(compile_range),
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
|
||||
return compiled_graph
|
||||
|
||||
@@ -427,19 +392,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
sym_shape_indices = [
|
||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||
]
|
||||
global compilation_start_time
|
||||
|
||||
compiled_graph_for_dynamic_shape = (
|
||||
self.vllm_backend.compiler_manager.compile(
|
||||
submod,
|
||||
args,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None,
|
||||
)
|
||||
)
|
||||
# Lazy import here to avoid circular import
|
||||
from .piecewise_backend import PiecewiseBackend
|
||||
|
||||
@@ -449,7 +402,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
index,
|
||||
len(self.compile_submod_names),
|
||||
sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape,
|
||||
self.vllm_backend,
|
||||
)
|
||||
|
||||
@@ -589,8 +541,13 @@ class VllmBackend:
|
||||
)
|
||||
else:
|
||||
# Config should automatically wrap all inductor passes
|
||||
assert isinstance(self.inductor_config[self.pass_key], InductorPass)
|
||||
self.pass_manager.add(self.inductor_config[self.pass_key])
|
||||
assert isinstance(
|
||||
self.compilation_config.inductor_compile_config[self.pass_key],
|
||||
InductorPass,
|
||||
)
|
||||
self.pass_manager.add(
|
||||
self.compilation_config.inductor_compile_config[self.pass_key]
|
||||
)
|
||||
self.inductor_config[self.pass_key] = self.pass_manager
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -10,6 +10,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
@@ -431,7 +432,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable(self, shape: int | None) -> bool:
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass is applied on top of the sequence parallelism pass.
|
||||
# It inherits the same applicability condition as `SequenceParallelismPass`.
|
||||
# See `SequenceParallelismPass.is_applicable` for more details.
|
||||
@@ -441,7 +442,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return shape is not None and shape % tp_size == 0
|
||||
return compile_range.is_single_size() and compile_range.end % tp_size == 0
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
@@ -505,91 +506,60 @@ if flashinfer_comm is not None:
|
||||
num_tokens, hidden_size = allreduce_in.shape
|
||||
element_size = allreduce_in.element_size()
|
||||
current_tensor_size = num_tokens * hidden_size * element_size
|
||||
max_tensor_size = max_token_num * hidden_size * element_size
|
||||
assert current_tensor_size <= max_tensor_size, (
|
||||
f"Current tensor size {current_tensor_size} is larger than "
|
||||
f"max token num {max_token_num} * hidden size {hidden_size} * "
|
||||
f"element size {element_size}"
|
||||
)
|
||||
device_capability = current_platform.get_device_capability().to_int()
|
||||
# Get one shot input size limit for the current world size
|
||||
# for the current device capability
|
||||
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
|
||||
device_capability, {}
|
||||
).get(world_size, None)
|
||||
# Use one shot if no max size is specified
|
||||
use_oneshot = (
|
||||
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
|
||||
)
|
||||
|
||||
if num_tokens <= max_token_num:
|
||||
device_capability = current_platform.get_device_capability().to_int()
|
||||
# Get one shot input size limit for the current world size
|
||||
# for the current device capability
|
||||
max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
|
||||
device_capability, {}
|
||||
).get(world_size, None)
|
||||
# Use one shot if no max size for one shot is specified
|
||||
use_oneshot = (
|
||||
max_one_shot_size_mb is None
|
||||
or current_tensor_size <= max_one_shot_size_mb * MiB
|
||||
)
|
||||
|
||||
assert _FI_WORKSPACE_TENSOR is not None, (
|
||||
"Flashinfer must be enabled when using flashinfer"
|
||||
)
|
||||
if norm_out is None:
|
||||
norm_out = allreduce_in
|
||||
residual_out = residual
|
||||
else:
|
||||
# return residual_out as allreduce_out with zeroed residual_in
|
||||
# as flashinfer does not support rms_norm
|
||||
# and allreduce_out together
|
||||
residual_out = allreduce_in
|
||||
# For the sizes that are smaller than the max size,
|
||||
# we only use flashinfer one shot allreduce
|
||||
flashinfer_comm.trtllm_allreduce_fusion(
|
||||
allreduce_in=allreduce_in,
|
||||
token_num=allreduce_in.shape[0],
|
||||
residual_in=residual,
|
||||
residual_out=residual_out,
|
||||
norm_out=norm_out,
|
||||
rms_gamma=rms_gamma,
|
||||
rms_eps=rms_eps,
|
||||
world_rank=world_rank,
|
||||
world_size=world_size,
|
||||
hidden_dim=allreduce_in.shape[-1],
|
||||
workspace_ptrs=_FI_WORKSPACE_TENSOR,
|
||||
launch_with_pdl=launch_with_pdl,
|
||||
use_oneshot=use_oneshot,
|
||||
trigger_completion_at_end=trigger_completion_at_end,
|
||||
fp32_acc=fp32_acc,
|
||||
pattern_code=pattern_code,
|
||||
allreduce_out=None,
|
||||
quant_out=quant_out,
|
||||
scale_out=scale_out,
|
||||
# in vllm we only support swizzled layout
|
||||
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
assert _FI_WORKSPACE_TENSOR is not None, (
|
||||
"Flashinfer must be enabled when using flashinfer"
|
||||
)
|
||||
if norm_out is None:
|
||||
norm_out = allreduce_in
|
||||
residual_out = residual
|
||||
else:
|
||||
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
|
||||
if scale_factor is not None and scale_out is None:
|
||||
# Do fused rms norm static fp8 quant fused op
|
||||
if norm_out is None:
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
|
||||
quant_out,
|
||||
allreduce_out,
|
||||
residual,
|
||||
rms_gamma,
|
||||
scale_factor,
|
||||
rms_eps,
|
||||
)
|
||||
else:
|
||||
torch.ops._C.rms_norm_static_fp8_quant(
|
||||
quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps
|
||||
)
|
||||
else:
|
||||
if norm_out is None:
|
||||
torch.ops._C.fused_add_rms_norm(
|
||||
allreduce_out, residual, rms_gamma, rms_eps
|
||||
)
|
||||
norm_out = allreduce_out
|
||||
else:
|
||||
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
|
||||
if scale_factor is not None and scale_out is not None:
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
quant_out, norm_out, scale_out, scale_factor
|
||||
)
|
||||
if scale_factor is None or norm_out is not None:
|
||||
# we need to return allreduce output
|
||||
# in cases of non quant fused AR + RMS norm
|
||||
# and fused AR + RMS norm + quant without fused add
|
||||
allreduce_in.copy_(allreduce_out)
|
||||
# return residual_out as allreduce_out with zeroed residual_in
|
||||
# as flashinfer does not support rms_norm
|
||||
# and allreduce_out together
|
||||
residual_out = allreduce_in
|
||||
# For the sizes that are smaller than the max size,
|
||||
# we only use flashinfer one shot allreduce
|
||||
flashinfer_comm.trtllm_allreduce_fusion(
|
||||
allreduce_in=allreduce_in,
|
||||
token_num=allreduce_in.shape[0],
|
||||
residual_in=residual,
|
||||
residual_out=residual_out,
|
||||
norm_out=norm_out,
|
||||
rms_gamma=rms_gamma,
|
||||
rms_eps=rms_eps,
|
||||
world_rank=world_rank,
|
||||
world_size=world_size,
|
||||
hidden_dim=allreduce_in.shape[-1],
|
||||
workspace_ptrs=_FI_WORKSPACE_TENSOR,
|
||||
launch_with_pdl=launch_with_pdl,
|
||||
use_oneshot=use_oneshot,
|
||||
trigger_completion_at_end=trigger_completion_at_end,
|
||||
fp32_acc=fp32_acc,
|
||||
pattern_code=pattern_code,
|
||||
allreduce_out=None,
|
||||
quant_out=quant_out,
|
||||
scale_out=scale_out,
|
||||
# in vllm we only support swizzled layout
|
||||
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
|
||||
def call_trtllm_fused_allreduce_norm_fake(
|
||||
allreduce_in: torch.Tensor,
|
||||
@@ -1128,7 +1098,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
if max_size is None:
|
||||
# Flashinfer doesn't support current world size
|
||||
logger.warning(
|
||||
"Flashinfer allreduce fusion is not supported for world size %s",
|
||||
"Flashinfer allreduce fusion is not supported for world size %s"
|
||||
" or max size is not provided",
|
||||
self.tp_size,
|
||||
)
|
||||
return
|
||||
@@ -1216,6 +1187,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
|
||||
self.disabled = False
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
return compile_range.end <= self.max_token_num
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
if self.disabled:
|
||||
|
||||
@@ -15,6 +15,7 @@ import torch.fx as fx
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
@@ -63,16 +64,16 @@ class CompilerInterface:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a runtime shape. If the `runtime_shape` is None, it means
|
||||
the `example_inputs` have a dynamic shape. Otherwise, the
|
||||
`runtime_shape` specifies the shape of the inputs. Right now we only
|
||||
support one variable shape for all inputs, which is the batchsize
|
||||
(number of tokens) during inference.
|
||||
with a range. The `compile_range` specifies the range of the inputs,
|
||||
it could be concrete size (if compile_sizes is provided), e.g. [4, 4]
|
||||
or a range [5, 8].
|
||||
Right now we only support one variable in ranges for all inputs,
|
||||
which is the batchsize (number of tokens) during inference.
|
||||
|
||||
Dynamo will make sure `graph(*example_inputs)` is valid.
|
||||
|
||||
@@ -98,7 +99,7 @@ class CompilerInterface:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
@@ -212,20 +213,20 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
|
||||
if isinstance(runtime_shape, int):
|
||||
if compile_range.is_single_size():
|
||||
dynamic_shapes = "from_example_inputs"
|
||||
else:
|
||||
dynamic_shapes = "from_tracing_context"
|
||||
dynamic_shapes = "from_graph"
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
|
||||
@@ -235,7 +236,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
options={"config_patches": current_config},
|
||||
)
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
path = os.path.join(self.cache_dir, key)
|
||||
@@ -251,7 +251,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
@@ -315,7 +315,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
@@ -329,7 +329,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
current_config["fx_graph_cache"] = True
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
@@ -512,7 +512,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
@@ -608,9 +608,9 @@ class InductorAdaptor(CompilerInterface):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config, runtime_shape):
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
def set_inductor_config(config, compile_range: Range):
|
||||
if compile_range.is_single_size():
|
||||
# for a specific batch size, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||
config["coordinate_descent_tuning"] = (
|
||||
@@ -630,7 +630,7 @@ class EagerAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_eager_compiles += 1
|
||||
|
||||
@@ -14,6 +14,7 @@ import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
from vllm.config.utils import Range
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
@@ -28,8 +29,8 @@ _pass_context = None
|
||||
|
||||
|
||||
class PassContext:
|
||||
def __init__(self, runtime_shape: int | None):
|
||||
self.runtime_shape = runtime_shape
|
||||
def __init__(self, compile_range: Range):
|
||||
self.compile_range: Range = compile_range
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
|
||||
@@ -39,13 +40,13 @@ def get_pass_context() -> PassContext:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pass_context(runtime_shape: int | None):
|
||||
def pass_context(compile_range: Range):
|
||||
"""A context manager that stores the current pass context,
|
||||
usually it is a list of sizes to specialize.
|
||||
"""
|
||||
global _pass_context
|
||||
prev_context = _pass_context
|
||||
_pass_context = PassContext(runtime_shape)
|
||||
_pass_context = PassContext(compile_range)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@@ -96,7 +97,7 @@ class InductorPass(CustomGraphPass):
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable(self, shape: int | None):
|
||||
def is_applicable_for_range(self, compile_range: Range):
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,11 @@ if current_platform.is_cuda():
|
||||
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
|
||||
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
from .inductor_pass import (
|
||||
CustomGraphPass,
|
||||
InductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -70,13 +74,13 @@ class PostGradPassManager(CustomGraphPass):
|
||||
def __call__(self, graph: fx.Graph):
|
||||
VllmInductorPass.dump_prefix = 0 # reset dump index
|
||||
|
||||
shape = get_pass_context().runtime_shape
|
||||
compile_range = get_pass_context().compile_range
|
||||
for pass_ in self.passes:
|
||||
if pass_.is_applicable(shape):
|
||||
if pass_.is_applicable_for_range(compile_range):
|
||||
pass_(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
else:
|
||||
logger.debug("Skipping %s with shape %s", pass_, shape)
|
||||
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
|
||||
|
||||
# post-cleanup goes before fix_functionalization
|
||||
# because it requires a functional graph
|
||||
@@ -133,4 +137,8 @@ class PostGradPassManager(CustomGraphPass):
|
||||
state["passes"].append(pass_.uuid())
|
||||
state["passes"].append(self.fix_functionalization.uuid())
|
||||
|
||||
# Include the compile range in the uuid to ensure that inductor
|
||||
# recompiles the graph for the new dynamic compile range.
|
||||
state["compile_range"] = str(get_pass_context().compile_range)
|
||||
|
||||
return InductorPass.hash_dict(state)
|
||||
|
||||
@@ -7,18 +7,18 @@ from typing import Any
|
||||
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ConcreteSizeEntry:
|
||||
runtime_shape: int
|
||||
class RangeEntry:
|
||||
compile_range: Range
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
|
||||
@@ -31,7 +31,6 @@ class PiecewiseBackend:
|
||||
piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend,
|
||||
):
|
||||
"""
|
||||
@@ -55,67 +54,111 @@ class PiecewiseBackend:
|
||||
|
||||
self.is_full_graph = total_piecewise_compiles == 1
|
||||
|
||||
self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)
|
||||
self.compile_ranges = self.compilation_config.get_compile_ranges()
|
||||
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
|
||||
logger.debug_once(log_string)
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
|
||||
self.compile_sizes = self.compilation_config.compile_sizes
|
||||
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
|
||||
logger.debug_once(log_string)
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
# the entries for ranges that we need to either
|
||||
self.range_entries: dict[Range, RangeEntry] = {}
|
||||
|
||||
# the entries for different shapes that we need to compile
|
||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||
# to_be_compiled_ranges tracks the remaining ranges to compile,
|
||||
# and updates during the compilation process, so we need to copy it
|
||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
|
||||
|
||||
# We only keep compilation management inside this class directly.
|
||||
for shape in self.compile_sizes:
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_shape=shape,
|
||||
runnable=self.compiled_graph_for_general_shape,
|
||||
for size in self.compile_sizes:
|
||||
range = Range(start=size, end=size)
|
||||
if range not in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
self.to_be_compiled_ranges.add(range)
|
||||
|
||||
for range in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
if self.is_last_graph and not self.to_be_compiled_ranges:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if not self.first_run_finished:
|
||||
self.first_run_finished = True
|
||||
self.check_for_ending_compilation()
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
def _fakify_args(self, args: list[Any]) -> list[Any]:
|
||||
# We need to pass fake example_inputs, otherwise torch.compile
|
||||
# will fakify the example_inputs potentially causing some non dynamic
|
||||
# dimension to be be duck shaped to other existing shapes that have hints
|
||||
# matching their values.
|
||||
# This is problem because it can lead to unintended specializations!
|
||||
# if the new wrongly dynamic dim is specialized
|
||||
# it will force specializing the whole shape
|
||||
# torch.compile probably should not accept
|
||||
# non fake tensors as example inputs!
|
||||
# See issue https://github.com/vllm-project/vllm/issues/27899
|
||||
fake_example_inputs = []
|
||||
for node in self.graph.graph.nodes:
|
||||
# All place holders come first
|
||||
if node.op == "placeholder":
|
||||
fake_example_inputs.append(node.meta["example_value"])
|
||||
else:
|
||||
break
|
||||
assert len(fake_example_inputs) == len(args)
|
||||
return fake_example_inputs
|
||||
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any:
|
||||
if not range_entry.compiled:
|
||||
range_entry.compiled = True
|
||||
self.to_be_compiled_ranges.remove(range_entry.compile_range)
|
||||
|
||||
if runtime_shape not in self.concrete_size_entries:
|
||||
# we don't need to do anything for this shape
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
entry = self.concrete_size_entries[runtime_shape]
|
||||
|
||||
if not entry.compiled:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
# fakify for range, real args for concrete size.
|
||||
# For concrete size, we clear the shape env in
|
||||
# compiler_manager.compile() so no need to fakify.
|
||||
args = (
|
||||
self._fakify_args(args)
|
||||
if not range_entry.compile_range.is_single_size()
|
||||
else args
|
||||
)
|
||||
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
compile_range=range_entry.compile_range,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape,
|
||||
)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
return entry.runnable(*args)
|
||||
def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
|
||||
# First we try to find the range entry for the concrete compile size
|
||||
# If not found, we search for the range entry
|
||||
# that contains the runtime shape.
|
||||
if runtime_shape in self.compile_sizes:
|
||||
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
|
||||
else:
|
||||
for range in self.compile_ranges:
|
||||
if runtime_shape in range:
|
||||
return self.range_entries[range]
|
||||
return None
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
range_entry = self._find_range_for_shape(runtime_shape)
|
||||
|
||||
assert range_entry is not None, (
|
||||
f"Shape out of considered range: {runtime_shape} "
|
||||
"[1, max_num_batched_tokens]"
|
||||
)
|
||||
|
||||
self._maybe_compile_for_range_entry(range_entry, args)
|
||||
return range_entry.runnable(*args)
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
@@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable(self, shape: int | None) -> bool:
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# When sequence parallelism is enabled, the residual tensor from RMSNorm
|
||||
# needs to be split along the sequence dimension. However, this dimension
|
||||
# is symbolic during piecewise compilation, and splitting symbolic shapes
|
||||
@@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return shape is not None and shape % tp_size == 0
|
||||
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
|
||||
Reference in New Issue
Block a user