[torch.compile] Rename compile_ranges_split_points to compile_ranges_endpoints (#36027)
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -46,10 +46,10 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# Get the compile ranges split points after vllm config post init
|
||||
# Get the compile ranges endpoints after vllm config post init
|
||||
# in order to compute compile ranges correctly
|
||||
compilation_config.compile_ranges_split_points = (
|
||||
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
|
||||
compilation_config.compile_ranges_endpoints = (
|
||||
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
compile_ranges_split_points=[8, 32],
|
||||
compile_ranges_endpoints=[8, 32],
|
||||
compile_sizes=[16, 64, 128],
|
||||
inductor_compile_config={
|
||||
"post_grad_custom_post_pass": post_grad_range_checker,
|
||||
@@ -110,7 +110,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
|
||||
|
||||
def test_compile_config_get_compile_ranges():
|
||||
compilation_config = CompilationConfig(
|
||||
compile_ranges_split_points=[8, 32],
|
||||
compile_ranges_endpoints=[8, 32],
|
||||
)
|
||||
VllmConfig(
|
||||
scheduler_config=SchedulerConfig(
|
||||
@@ -149,7 +149,7 @@ def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
|
||||
scheduler_config=scheduler_config,
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
compile_ranges_split_points=[8],
|
||||
compile_ranges_endpoints=[8],
|
||||
inductor_compile_config={
|
||||
"post_grad_custom_post_pass": post_grad_range_checker,
|
||||
},
|
||||
|
||||
@@ -885,8 +885,8 @@ class VllmBackend:
|
||||
"splitting_ops": list_to_str(cc.splitting_ops),
|
||||
"cudagraph_mode": str(cc.cudagraph_mode),
|
||||
"compile_sizes": list_to_str(cc.compile_sizes),
|
||||
"compile_ranges_split_points": list_to_str(
|
||||
cc.compile_ranges_split_points
|
||||
"compile_ranges_endpoints": list_to_str(
|
||||
cc.compile_ranges_endpoints
|
||||
),
|
||||
"use_inductor_graph_partition": cc.use_inductor_graph_partition,
|
||||
"inductor_passes": list_to_str(list(cc.inductor_passes.keys())),
|
||||
|
||||
@@ -382,8 +382,8 @@ class CompilationConfig:
|
||||
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
|
||||
- Inductor compilation:
|
||||
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
|
||||
- [`compile_ranges_split_points`]
|
||||
[vllm.config.CompilationConfig.compile_ranges_split_points]
|
||||
- [`compile_ranges_endpoints`]
|
||||
[vllm.config.CompilationConfig.compile_ranges_endpoints]
|
||||
- [`inductor_compile_config`]
|
||||
[vllm.config.CompilationConfig.inductor_compile_config]
|
||||
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
|
||||
@@ -492,12 +492,12 @@ class CompilationConfig:
|
||||
to integers, it also supports "cudagraph_capture_sizes" to
|
||||
specify the sizes for cudagraph capture."""
|
||||
|
||||
compile_ranges_split_points: list[int] | None = None
|
||||
"""Split points that represent compile ranges for inductor.
|
||||
compile_ranges_endpoints: list[int] | None = None
|
||||
"""Endpoints for Inductor compile ranges.
|
||||
The compile ranges are
|
||||
[1, split_points[0]],
|
||||
[split_points[0] + 1, split_points[1]], ...,
|
||||
[split_points[-1] + 1, max_num_batched_tokens].
|
||||
[1, endpoints[0]],
|
||||
[endpoints[0] + 1, endpoints[1]], ...,
|
||||
[endpoints[-1] + 1, max_num_batched_tokens].
|
||||
Compile sizes are also used single element ranges,
|
||||
the range is represented as [compile_sizes[i], compile_sizes[i]].
|
||||
|
||||
@@ -1246,10 +1246,9 @@ class CompilationConfig:
|
||||
|
||||
def get_compile_ranges(self) -> list[Range]:
|
||||
"""Get the compile ranges for the compilation config."""
|
||||
if self.compile_ranges_split_points is None:
|
||||
if self.compile_ranges_endpoints is None:
|
||||
return []
|
||||
split_points = sorted(set(self.compile_ranges_split_points))
|
||||
endpoints = sorted(set(self.compile_ranges_endpoints))
|
||||
return [
|
||||
Range(start=s + 1, end=e)
|
||||
for s, e in zip([0] + split_points[:-1], split_points)
|
||||
Range(start=s + 1, end=e) for s, e in zip([0] + endpoints[:-1], endpoints)
|
||||
]
|
||||
|
||||
@@ -1451,12 +1451,12 @@ class VllmConfig:
|
||||
Set the compile ranges for the compilation config.
|
||||
"""
|
||||
compilation_config = self.compilation_config
|
||||
computed_compile_ranges_split_points = []
|
||||
computed_compile_ranges_endpoints = []
|
||||
|
||||
# The upper bound of the compile ranges is the max_num_batched_tokens.
|
||||
compile_range_end = self.scheduler_config.max_num_batched_tokens
|
||||
if compile_range_end is not None:
|
||||
computed_compile_ranges_split_points.append(compile_range_end)
|
||||
computed_compile_ranges_endpoints.append(compile_range_end)
|
||||
|
||||
# Add the compile ranges for flashinfer
|
||||
if compilation_config.pass_config.fuse_allreduce_rms:
|
||||
@@ -1468,7 +1468,7 @@ class VllmConfig:
|
||||
* self.model_config.dtype.itemsize
|
||||
)
|
||||
if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
computed_compile_ranges_split_points.append(max_token_num)
|
||||
computed_compile_ranges_endpoints.append(max_token_num)
|
||||
else:
|
||||
logger.debug(
|
||||
"Max num batched tokens below allreduce-rms fusion threshold, "
|
||||
@@ -1500,10 +1500,10 @@ class VllmConfig:
|
||||
and min_token_num < max_num_batched_tokens
|
||||
and min_token_num > 1
|
||||
):
|
||||
# Add split point at min_token_num - 1 to ensure SP applies
|
||||
# Add endpoint at min_token_num - 1 to ensure SP applies
|
||||
# starting from min_token_num
|
||||
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
|
||||
computed_compile_ranges_split_points.append(min_token_num - 1)
|
||||
computed_compile_ranges_endpoints.append(min_token_num - 1)
|
||||
|
||||
if compilation_config.pass_config.fuse_rope_kvcache:
|
||||
max_token_num = (
|
||||
@@ -1511,7 +1511,7 @@ class VllmConfig:
|
||||
)
|
||||
if max_token_num is not None:
|
||||
if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
computed_compile_ranges_split_points.append(max_token_num)
|
||||
computed_compile_ranges_endpoints.append(max_token_num)
|
||||
else:
|
||||
logger.debug(
|
||||
"Max num batched tokens below rope+kvcache fusion threshold, "
|
||||
@@ -1519,14 +1519,14 @@ class VllmConfig:
|
||||
compile_range_end,
|
||||
)
|
||||
|
||||
if compilation_config.compile_ranges_split_points is not None:
|
||||
for x in compilation_config.compile_ranges_split_points:
|
||||
if compilation_config.compile_ranges_endpoints is not None:
|
||||
for x in compilation_config.compile_ranges_endpoints:
|
||||
assert isinstance(x, int)
|
||||
assert x > 0, f"Invalid compile range split point: {x}"
|
||||
assert x > 0, f"Invalid compile range endpoint: {x}"
|
||||
if compile_range_end is not None and x < compile_range_end and x > 1:
|
||||
computed_compile_ranges_split_points.append(x)
|
||||
compilation_config.compile_ranges_split_points = sorted(
|
||||
computed_compile_ranges_split_points
|
||||
computed_compile_ranges_endpoints.append(x)
|
||||
compilation_config.compile_ranges_endpoints = sorted(
|
||||
computed_compile_ranges_endpoints
|
||||
)
|
||||
|
||||
def try_verify_and_update_config(self):
|
||||
|
||||
Reference in New Issue
Block a user