[torch.compile] Rename compile_ranges_split_points to compile_ranges_endpoints (#36027)
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -382,8 +382,8 @@ class CompilationConfig:
|
||||
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
|
||||
- Inductor compilation:
|
||||
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
|
||||
- [`compile_ranges_split_points`]
|
||||
[vllm.config.CompilationConfig.compile_ranges_split_points]
|
||||
- [`compile_ranges_endpoints`]
|
||||
[vllm.config.CompilationConfig.compile_ranges_endpoints]
|
||||
- [`inductor_compile_config`]
|
||||
[vllm.config.CompilationConfig.inductor_compile_config]
|
||||
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
|
||||
@@ -492,12 +492,12 @@ class CompilationConfig:
|
||||
to integers, it also supports "cudagraph_capture_sizes" to
|
||||
specify the sizes for cudagraph capture."""
|
||||
|
||||
compile_ranges_split_points: list[int] | None = None
|
||||
"""Split points that represent compile ranges for inductor.
|
||||
compile_ranges_endpoints: list[int] | None = None
|
||||
"""Endpoints for Inductor compile ranges.
|
||||
The compile ranges are
|
||||
[1, split_points[0]],
|
||||
[split_points[0] + 1, split_points[1]], ...,
|
||||
[split_points[-1] + 1, max_num_batched_tokens].
|
||||
[1, endpoints[0]],
|
||||
[endpoints[0] + 1, endpoints[1]], ...,
|
||||
[endpoints[-1] + 1, max_num_batched_tokens].
|
||||
Compile sizes are also used as single element ranges,
|
||||
the range is represented as [compile_sizes[i], compile_sizes[i]].
|
||||
|
||||
@@ -1246,10 +1246,9 @@ class CompilationConfig:
|
||||
|
||||
def get_compile_ranges(self) -> list[Range]:
|
||||
"""Get the compile ranges for the compilation config."""
|
||||
if self.compile_ranges_split_points is None:
|
||||
if self.compile_ranges_endpoints is None:
|
||||
return []
|
||||
split_points = sorted(set(self.compile_ranges_split_points))
|
||||
endpoints = sorted(set(self.compile_ranges_endpoints))
|
||||
return [
|
||||
Range(start=s + 1, end=e)
|
||||
for s, e in zip([0] + split_points[:-1], split_points)
|
||||
Range(start=s + 1, end=e) for s, e in zip([0] + endpoints[:-1], endpoints)
|
||||
]
|
||||
|
||||
@@ -1451,12 +1451,12 @@ class VllmConfig:
|
||||
Set the compile ranges for the compilation config.
|
||||
"""
|
||||
compilation_config = self.compilation_config
|
||||
computed_compile_ranges_split_points = []
|
||||
computed_compile_ranges_endpoints = []
|
||||
|
||||
# The upper bound of the compile ranges is the max_num_batched_tokens.
|
||||
compile_range_end = self.scheduler_config.max_num_batched_tokens
|
||||
if compile_range_end is not None:
|
||||
computed_compile_ranges_split_points.append(compile_range_end)
|
||||
computed_compile_ranges_endpoints.append(compile_range_end)
|
||||
|
||||
# Add the compile ranges for flashinfer
|
||||
if compilation_config.pass_config.fuse_allreduce_rms:
|
||||
@@ -1468,7 +1468,7 @@ class VllmConfig:
|
||||
* self.model_config.dtype.itemsize
|
||||
)
|
||||
if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
computed_compile_ranges_split_points.append(max_token_num)
|
||||
computed_compile_ranges_endpoints.append(max_token_num)
|
||||
else:
|
||||
logger.debug(
|
||||
"Max num batched tokens below allreduce-rms fusion threshold, "
|
||||
@@ -1500,10 +1500,10 @@ class VllmConfig:
|
||||
and min_token_num < max_num_batched_tokens
|
||||
and min_token_num > 1
|
||||
):
|
||||
# Add split point at min_token_num - 1 to ensure SP applies
|
||||
# Add endpoint at min_token_num - 1 to ensure SP applies
|
||||
# starting from min_token_num
|
||||
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
|
||||
computed_compile_ranges_split_points.append(min_token_num - 1)
|
||||
computed_compile_ranges_endpoints.append(min_token_num - 1)
|
||||
|
||||
if compilation_config.pass_config.fuse_rope_kvcache:
|
||||
max_token_num = (
|
||||
@@ -1511,7 +1511,7 @@ class VllmConfig:
|
||||
)
|
||||
if max_token_num is not None:
|
||||
if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
computed_compile_ranges_split_points.append(max_token_num)
|
||||
computed_compile_ranges_endpoints.append(max_token_num)
|
||||
else:
|
||||
logger.debug(
|
||||
"Max num batched tokens below rope+kvcache fusion threshold, "
|
||||
@@ -1519,14 +1519,14 @@ class VllmConfig:
|
||||
compile_range_end,
|
||||
)
|
||||
|
||||
if compilation_config.compile_ranges_split_points is not None:
|
||||
for x in compilation_config.compile_ranges_split_points:
|
||||
if compilation_config.compile_ranges_endpoints is not None:
|
||||
for x in compilation_config.compile_ranges_endpoints:
|
||||
assert isinstance(x, int)
|
||||
assert x > 0, f"Invalid compile range split point: {x}"
|
||||
assert x > 0, f"Invalid compile range endpoint: {x}"
|
||||
if compile_range_end is not None and x < compile_range_end and x > 1:
|
||||
computed_compile_ranges_split_points.append(x)
|
||||
compilation_config.compile_ranges_split_points = sorted(
|
||||
computed_compile_ranges_split_points
|
||||
computed_compile_ranges_endpoints.append(x)
|
||||
compilation_config.compile_ranges_endpoints = sorted(
|
||||
computed_compile_ranges_endpoints
|
||||
)
|
||||
|
||||
def try_verify_and_update_config(self):
|
||||
|
||||
Reference in New Issue
Block a user