[torch.compile] Rename compile_ranges_split_points to compile_ranges_endpoints (#36027)

Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Copilot
2026-03-09 18:04:40 +00:00
committed by GitHub
parent fa028207aa
commit 4b87ffbefb
5 changed files with 30 additions and 31 deletions

View File

@@ -46,10 +46,10 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Get the compile ranges split points after vllm config post init
# Get the compile ranges endpoints after vllm config post init
# in order to compute compile ranges correctly
compilation_config.compile_ranges_split_points = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
compilation_config.compile_ranges_endpoints = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
)

View File

@@ -86,7 +86,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8, 32],
compile_ranges_endpoints=[8, 32],
compile_sizes=[16, 64, 128],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
@@ -110,7 +110,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
def test_compile_config_get_compile_ranges():
compilation_config = CompilationConfig(
compile_ranges_split_points=[8, 32],
compile_ranges_endpoints=[8, 32],
)
VllmConfig(
scheduler_config=SchedulerConfig(
@@ -149,7 +149,7 @@ def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
scheduler_config=scheduler_config,
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8],
compile_ranges_endpoints=[8],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
},

View File

@@ -885,8 +885,8 @@ class VllmBackend:
"splitting_ops": list_to_str(cc.splitting_ops),
"cudagraph_mode": str(cc.cudagraph_mode),
"compile_sizes": list_to_str(cc.compile_sizes),
"compile_ranges_split_points": list_to_str(
cc.compile_ranges_split_points
"compile_ranges_endpoints": list_to_str(
cc.compile_ranges_endpoints
),
"use_inductor_graph_partition": cc.use_inductor_graph_partition,
"inductor_passes": list_to_str(list(cc.inductor_passes.keys())),

View File

@@ -382,8 +382,8 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation:
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`compile_ranges_split_points`]
[vllm.config.CompilationConfig.compile_ranges_split_points]
- [`compile_ranges_endpoints`]
[vllm.config.CompilationConfig.compile_ranges_endpoints]
- [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config]
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
@@ -492,12 +492,12 @@ class CompilationConfig:
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
compile_ranges_split_points: list[int] | None = None
"""Split points that represent compile ranges for inductor.
compile_ranges_endpoints: list[int] | None = None
"""Endpoints for Inductor compile ranges.
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens].
[1, endpoints[0]],
[endpoints[0] + 1, endpoints[1]], ...,
[endpoints[-1] + 1, max_num_batched_tokens].
Compile sizes are also used as single-element ranges;
each such range is represented as [compile_sizes[i], compile_sizes[i]].
@@ -1246,10 +1246,9 @@ class CompilationConfig:
def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None:
if self.compile_ranges_endpoints is None:
return []
split_points = sorted(set(self.compile_ranges_split_points))
endpoints = sorted(set(self.compile_ranges_endpoints))
return [
Range(start=s + 1, end=e)
for s, e in zip([0] + split_points[:-1], split_points)
Range(start=s + 1, end=e) for s, e in zip([0] + endpoints[:-1], endpoints)
]

View File

@@ -1451,12 +1451,12 @@ class VllmConfig:
Set the compile ranges for the compilation config.
"""
compilation_config = self.compilation_config
computed_compile_ranges_split_points = []
computed_compile_ranges_endpoints = []
# The upper bound of the compile ranges is the max_num_batched_tokens.
compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None:
computed_compile_ranges_split_points.append(compile_range_end)
computed_compile_ranges_endpoints.append(compile_range_end)
# Add the compile ranges for flashinfer
if compilation_config.pass_config.fuse_allreduce_rms:
@@ -1468,7 +1468,7 @@ class VllmConfig:
* self.model_config.dtype.itemsize
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below allreduce-rms fusion threshold, "
@@ -1500,10 +1500,10 @@ class VllmConfig:
and min_token_num < max_num_batched_tokens
and min_token_num > 1
):
# Add split point at min_token_num - 1 to ensure SP applies
# Add endpoint at min_token_num - 1 to ensure SP applies
# starting from min_token_num
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
computed_compile_ranges_split_points.append(min_token_num - 1)
computed_compile_ranges_endpoints.append(min_token_num - 1)
if compilation_config.pass_config.fuse_rope_kvcache:
max_token_num = (
@@ -1511,7 +1511,7 @@ class VllmConfig:
)
if max_token_num is not None:
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below rope+kvcache fusion threshold, "
@@ -1519,14 +1519,14 @@ class VllmConfig:
compile_range_end,
)
if compilation_config.compile_ranges_split_points is not None:
for x in compilation_config.compile_ranges_split_points:
if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int)
assert x > 0, f"Invalid compile range split point: {x}"
assert x > 0, f"Invalid compile range endpoint: {x}"
if compile_range_end is not None and x < compile_range_end and x > 1:
computed_compile_ranges_split_points.append(x)
compilation_config.compile_ranges_split_points = sorted(
computed_compile_ranges_split_points
computed_compile_ranges_endpoints.append(x)
compilation_config.compile_ranges_endpoints = sorted(
computed_compile_ranges_endpoints
)
def try_verify_and_update_config(self):