[torch.compile] Rename compile_ranges_split_points to compile_ranges_endpoints (#36027)

Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Copilot
2026-03-09 18:04:40 +00:00
committed by GitHub
parent fa028207aa
commit 4b87ffbefb
5 changed files with 30 additions and 31 deletions

View File

@@ -46,10 +46,10 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Get the compile ranges split points after vllm config post init # Get the compile ranges endpoints after vllm config post init
# in order to compute compile ranges correctly # in order to compute compile ranges correctly
compilation_config.compile_ranges_split_points = ( compilation_config.compile_ranges_endpoints = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
) )

View File

@@ -86,7 +86,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
), ),
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8, 32], compile_ranges_endpoints=[8, 32],
compile_sizes=[16, 64, 128], compile_sizes=[16, 64, 128],
inductor_compile_config={ inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker, "post_grad_custom_post_pass": post_grad_range_checker,
@@ -110,7 +110,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
def test_compile_config_get_compile_ranges(): def test_compile_config_get_compile_ranges():
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
compile_ranges_split_points=[8, 32], compile_ranges_endpoints=[8, 32],
) )
VllmConfig( VllmConfig(
scheduler_config=SchedulerConfig( scheduler_config=SchedulerConfig(
@@ -149,7 +149,7 @@ def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8], compile_ranges_endpoints=[8],
inductor_compile_config={ inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker, "post_grad_custom_post_pass": post_grad_range_checker,
}, },

View File

@@ -885,8 +885,8 @@ class VllmBackend:
"splitting_ops": list_to_str(cc.splitting_ops), "splitting_ops": list_to_str(cc.splitting_ops),
"cudagraph_mode": str(cc.cudagraph_mode), "cudagraph_mode": str(cc.cudagraph_mode),
"compile_sizes": list_to_str(cc.compile_sizes), "compile_sizes": list_to_str(cc.compile_sizes),
"compile_ranges_split_points": list_to_str( "compile_ranges_endpoints": list_to_str(
cc.compile_ranges_split_points cc.compile_ranges_endpoints
), ),
"use_inductor_graph_partition": cc.use_inductor_graph_partition, "use_inductor_graph_partition": cc.use_inductor_graph_partition,
"inductor_passes": list_to_str(list(cc.inductor_passes.keys())), "inductor_passes": list_to_str(list(cc.inductor_passes.keys())),

View File

@@ -382,8 +382,8 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_copy_inputs] [vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation: - Inductor compilation:
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`compile_ranges_split_points`] - [`compile_ranges_endpoints`]
[vllm.config.CompilationConfig.compile_ranges_split_points] [vllm.config.CompilationConfig.compile_ranges_endpoints]
- [`inductor_compile_config`] - [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config] [vllm.config.CompilationConfig.inductor_compile_config]
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
@@ -492,12 +492,12 @@ class CompilationConfig:
to integers, it also supports "cudagraph_capture_sizes" to to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture.""" specify the sizes for cudagraph capture."""
compile_ranges_split_points: list[int] | None = None compile_ranges_endpoints: list[int] | None = None
"""Split points that represent compile ranges for inductor. """Endpoints for Inductor compile ranges.
The compile ranges are The compile ranges are
[1, split_points[0]], [1, endpoints[0]],
[split_points[0] + 1, split_points[1]], ..., [endpoints[0] + 1, endpoints[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens]. [endpoints[-1] + 1, max_num_batched_tokens].
Compile sizes are also used single element ranges, Compile sizes are also used single element ranges,
the range is represented as [compile_sizes[i], compile_sizes[i]]. the range is represented as [compile_sizes[i], compile_sizes[i]].
@@ -1246,10 +1246,9 @@ class CompilationConfig:
def get_compile_ranges(self) -> list[Range]: def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config.""" """Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None: if self.compile_ranges_endpoints is None:
return [] return []
split_points = sorted(set(self.compile_ranges_split_points)) endpoints = sorted(set(self.compile_ranges_endpoints))
return [ return [
Range(start=s + 1, end=e) Range(start=s + 1, end=e) for s, e in zip([0] + endpoints[:-1], endpoints)
for s, e in zip([0] + split_points[:-1], split_points)
] ]

View File

@@ -1451,12 +1451,12 @@ class VllmConfig:
Set the compile ranges for the compilation config. Set the compile ranges for the compilation config.
""" """
compilation_config = self.compilation_config compilation_config = self.compilation_config
computed_compile_ranges_split_points = [] computed_compile_ranges_endpoints = []
# The upper bound of the compile ranges is the max_num_batched_tokens. # The upper bound of the compile ranges is the max_num_batched_tokens.
compile_range_end = self.scheduler_config.max_num_batched_tokens compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None: if compile_range_end is not None:
computed_compile_ranges_split_points.append(compile_range_end) computed_compile_ranges_endpoints.append(compile_range_end)
# Add the compile ranges for flashinfer # Add the compile ranges for flashinfer
if compilation_config.pass_config.fuse_allreduce_rms: if compilation_config.pass_config.fuse_allreduce_rms:
@@ -1468,7 +1468,7 @@ class VllmConfig:
* self.model_config.dtype.itemsize * self.model_config.dtype.itemsize
) )
if compile_range_end is not None and max_token_num < compile_range_end: if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num) computed_compile_ranges_endpoints.append(max_token_num)
else: else:
logger.debug( logger.debug(
"Max num batched tokens below allreduce-rms fusion threshold, " "Max num batched tokens below allreduce-rms fusion threshold, "
@@ -1500,10 +1500,10 @@ class VllmConfig:
and min_token_num < max_num_batched_tokens and min_token_num < max_num_batched_tokens
and min_token_num > 1 and min_token_num > 1
): ):
# Add split point at min_token_num - 1 to ensure SP applies # Add endpoint at min_token_num - 1 to ensure SP applies
# starting from min_token_num # starting from min_token_num
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies) # This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
computed_compile_ranges_split_points.append(min_token_num - 1) computed_compile_ranges_endpoints.append(min_token_num - 1)
if compilation_config.pass_config.fuse_rope_kvcache: if compilation_config.pass_config.fuse_rope_kvcache:
max_token_num = ( max_token_num = (
@@ -1511,7 +1511,7 @@ class VllmConfig:
) )
if max_token_num is not None: if max_token_num is not None:
if compile_range_end is not None and max_token_num < compile_range_end: if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num) computed_compile_ranges_endpoints.append(max_token_num)
else: else:
logger.debug( logger.debug(
"Max num batched tokens below rope+kvcache fusion threshold, " "Max num batched tokens below rope+kvcache fusion threshold, "
@@ -1519,14 +1519,14 @@ class VllmConfig:
compile_range_end, compile_range_end,
) )
if compilation_config.compile_ranges_split_points is not None: if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_split_points: for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int) assert isinstance(x, int)
assert x > 0, f"Invalid compile range split point: {x}" assert x > 0, f"Invalid compile range endpoint: {x}"
if compile_range_end is not None and x < compile_range_end and x > 1: if compile_range_end is not None and x < compile_range_end and x > 1:
computed_compile_ranges_split_points.append(x) computed_compile_ranges_endpoints.append(x)
compilation_config.compile_ranges_split_points = sorted( compilation_config.compile_ranges_endpoints = sorted(
computed_compile_ranges_split_points computed_compile_ranges_endpoints
) )
def try_verify_and_update_config(self): def try_verify_and_update_config(self):