[torch.compile] Rename compile_ranges_split_points to compile_ranges_endpoints (#36027)

Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Copilot
2026-03-09 18:04:40 +00:00
committed by GitHub
parent fa028207aa
commit 4b87ffbefb
5 changed files with 30 additions and 31 deletions

View File

@@ -382,8 +382,8 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation:
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`compile_ranges_split_points`]
[vllm.config.CompilationConfig.compile_ranges_split_points]
- [`compile_ranges_endpoints`]
[vllm.config.CompilationConfig.compile_ranges_endpoints]
- [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config]
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
@@ -492,12 +492,12 @@ class CompilationConfig:
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
compile_ranges_split_points: list[int] | None = None
"""Split points that represent compile ranges for inductor.
compile_ranges_endpoints: list[int] | None = None
"""Endpoints for Inductor compile ranges.
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens].
[1, endpoints[0]],
[endpoints[0] + 1, endpoints[1]], ...,
[endpoints[-1] + 1, max_num_batched_tokens].
Compile sizes are also used single element ranges,
the range is represented as [compile_sizes[i], compile_sizes[i]].
@@ -1246,10 +1246,9 @@ class CompilationConfig:
def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None:
if self.compile_ranges_endpoints is None:
return []
split_points = sorted(set(self.compile_ranges_split_points))
endpoints = sorted(set(self.compile_ranges_endpoints))
return [
Range(start=s + 1, end=e)
for s, e in zip([0] + split_points[:-1], split_points)
Range(start=s + 1, end=e) for s, e in zip([0] + endpoints[:-1], endpoints)
]

View File

@@ -1451,12 +1451,12 @@ class VllmConfig:
Set the compile ranges for the compilation config.
"""
compilation_config = self.compilation_config
computed_compile_ranges_split_points = []
computed_compile_ranges_endpoints = []
# The upper bound of the compile ranges is the max_num_batched_tokens.
compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None:
computed_compile_ranges_split_points.append(compile_range_end)
computed_compile_ranges_endpoints.append(compile_range_end)
# Add the compile ranges for flashinfer
if compilation_config.pass_config.fuse_allreduce_rms:
@@ -1468,7 +1468,7 @@ class VllmConfig:
* self.model_config.dtype.itemsize
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below allreduce-rms fusion threshold, "
@@ -1500,10 +1500,10 @@ class VllmConfig:
and min_token_num < max_num_batched_tokens
and min_token_num > 1
):
# Add split point at min_token_num - 1 to ensure SP applies
# Add endpoint at min_token_num - 1 to ensure SP applies
# starting from min_token_num
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
computed_compile_ranges_split_points.append(min_token_num - 1)
computed_compile_ranges_endpoints.append(min_token_num - 1)
if compilation_config.pass_config.fuse_rope_kvcache:
max_token_num = (
@@ -1511,7 +1511,7 @@ class VllmConfig:
)
if max_token_num is not None:
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below rope+kvcache fusion threshold, "
@@ -1519,14 +1519,14 @@ class VllmConfig:
compile_range_end,
)
if compilation_config.compile_ranges_split_points is not None:
for x in compilation_config.compile_ranges_split_points:
if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int)
assert x > 0, f"Invalid compile range split point: {x}"
assert x > 0, f"Invalid compile range endpoint: {x}"
if compile_range_end is not None and x < compile_range_end and x > 1:
computed_compile_ranges_split_points.append(x)
compilation_config.compile_ranges_split_points = sorted(
computed_compile_ranges_split_points
computed_compile_ranges_endpoints.append(x)
compilation_config.compile_ranges_endpoints = sorted(
computed_compile_ranges_endpoints
)
def try_verify_and_update_config(self):