From 4b87ffbefb3881a0a33f9c1cb7121429bddad666 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 18:04:40 +0000 Subject: [PATCH] [torch.compile] Rename `compile_ranges_split_points` to `compile_ranges_endpoints` (#36027) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič --- tests/compile/fusions_e2e/conftest.py | 6 +++--- tests/compile/test_compile_ranges.py | 6 +++--- vllm/compilation/backends.py | 4 ++-- vllm/config/compilation.py | 21 ++++++++++----------- vllm/config/vllm.py | 24 ++++++++++++------------ 5 files changed, 30 insertions(+), 31 deletions(-) diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index d083b6f14..29eb84251 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -46,10 +46,10 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - # Get the compile ranges split points after vllm config post init + # Get the compile ranges endpoints after vllm config post init # in order to compute compile ranges correctly - compilation_config.compile_ranges_split_points = ( - llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points + compilation_config.compile_ranges_endpoints = ( + llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints ) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 430db850c..286ed4a8b 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -86,7 +86,7 @@ def test_compile_ranges(use_fresh_inductor_cache): ), compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - compile_ranges_split_points=[8, 32], + compile_ranges_endpoints=[8, 32], compile_sizes=[16, 64, 128], inductor_compile_config={ "post_grad_custom_post_pass": post_grad_range_checker, @@ -110,7 +110,7 @@ def test_compile_ranges(use_fresh_inductor_cache): def test_compile_config_get_compile_ranges(): compilation_config = CompilationConfig( - compile_ranges_split_points=[8, 32], + compile_ranges_endpoints=[8, 32], ) VllmConfig( scheduler_config=SchedulerConfig( @@ -149,7 +149,7 @@ def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache): scheduler_config=scheduler_config, compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - compile_ranges_split_points=[8], + compile_ranges_endpoints=[8], inductor_compile_config={ "post_grad_custom_post_pass": post_grad_range_checker, }, diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 6325d91a1..c0c46d9e7 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -885,8 +885,8 @@ class VllmBackend: "splitting_ops": list_to_str(cc.splitting_ops), "cudagraph_mode": str(cc.cudagraph_mode), "compile_sizes": list_to_str(cc.compile_sizes), - "compile_ranges_split_points": list_to_str( - cc.compile_ranges_split_points + "compile_ranges_endpoints": list_to_str( + cc.compile_ranges_endpoints ), "use_inductor_graph_partition": cc.use_inductor_graph_partition, "inductor_passes": list_to_str(list(cc.inductor_passes.keys())), diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index bf91fda95..b829c31e7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -382,8 +382,8 @@ class CompilationConfig: [vllm.config.CompilationConfig.cudagraph_copy_inputs] - Inductor compilation: - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] - - [`compile_ranges_split_points`] - [vllm.config.CompilationConfig.compile_ranges_split_points] + - [`compile_ranges_endpoints`] + [vllm.config.CompilationConfig.compile_ranges_endpoints] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -492,12 +492,12 @@ class CompilationConfig: to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" - compile_ranges_split_points: list[int] | None = None - """Split points that represent compile ranges for inductor. + compile_ranges_endpoints: list[int] | None = None + """Endpoints for Inductor compile ranges. The compile ranges are - [1, split_points[0]], - [split_points[0] + 1, split_points[1]], ..., - [split_points[-1] + 1, max_num_batched_tokens]. + [1, endpoints[0]], + [endpoints[0] + 1, endpoints[1]], ..., + [endpoints[-1] + 1, max_num_batched_tokens]. Compile sizes are also used single element ranges, the range is represented as [compile_sizes[i], compile_sizes[i]]. @@ -1246,10 +1246,9 @@ class CompilationConfig: def get_compile_ranges(self) -> list[Range]: """Get the compile ranges for the compilation config.""" - if self.compile_ranges_split_points is None: + if self.compile_ranges_endpoints is None: return [] - split_points = sorted(set(self.compile_ranges_split_points)) + endpoints = sorted(set(self.compile_ranges_endpoints)) return [ - Range(start=s + 1, end=e) - for s, e in zip([0] + split_points[:-1], split_points) + Range(start=s + 1, end=e) for s, e in zip([0] + endpoints[:-1], endpoints) ] diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 682feff11..dc776fac1 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1451,12 +1451,12 @@ class VllmConfig: Set the compile ranges for the compilation config. """ compilation_config = self.compilation_config - computed_compile_ranges_split_points = [] + computed_compile_ranges_endpoints = [] # The upper bound of the compile ranges is the max_num_batched_tokens. compile_range_end = self.scheduler_config.max_num_batched_tokens if compile_range_end is not None: - computed_compile_ranges_split_points.append(compile_range_end) + computed_compile_ranges_endpoints.append(compile_range_end) # Add the compile ranges for flashinfer if compilation_config.pass_config.fuse_allreduce_rms: @@ -1468,7 +1468,7 @@ class VllmConfig: * self.model_config.dtype.itemsize ) if compile_range_end is not None and max_token_num < compile_range_end: - computed_compile_ranges_split_points.append(max_token_num) + computed_compile_ranges_endpoints.append(max_token_num) else: logger.debug( "Max num batched tokens below allreduce-rms fusion threshold, " @@ -1500,10 +1500,10 @@ class VllmConfig: and min_token_num < max_num_batched_tokens and min_token_num > 1 ): - # Add split point at min_token_num - 1 to ensure SP applies + # Add endpoint at min_token_num - 1 to ensure SP applies # starting from min_token_num # This creates ranges: [1, min-1] (no SP), [min, max] (SP applies) - computed_compile_ranges_split_points.append(min_token_num - 1) + computed_compile_ranges_endpoints.append(min_token_num - 1) if compilation_config.pass_config.fuse_rope_kvcache: max_token_num = ( @@ -1511,7 +1511,7 @@ class VllmConfig: ) if max_token_num is not None: if compile_range_end is not None and max_token_num < compile_range_end: - computed_compile_ranges_split_points.append(max_token_num) + computed_compile_ranges_endpoints.append(max_token_num) else: logger.debug( "Max num batched tokens below rope+kvcache fusion threshold, " @@ -1519,14 +1519,14 @@ class VllmConfig: compile_range_end, ) - if compilation_config.compile_ranges_split_points is not None: - for x in compilation_config.compile_ranges_split_points: + if compilation_config.compile_ranges_endpoints is not None: + for x in compilation_config.compile_ranges_endpoints: assert isinstance(x, int) - assert x > 0, f"Invalid compile range split point: {x}" + assert x > 0, f"Invalid compile range endpoint: {x}" if compile_range_end is not None and x < compile_range_end and x > 1: - computed_compile_ranges_split_points.append(x) - compilation_config.compile_ranges_split_points = sorted( - computed_compile_ranges_split_points + computed_compile_ranges_endpoints.append(x) + compilation_config.compile_ranges_endpoints = sorted( + computed_compile_ranges_endpoints ) def try_verify_and_update_config(self):