[docs][torch.compile] Add fusions.md — kernel/operator fusion reference page (#35538)
Signed-off-by: ProExpertProg <luka.govedic@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com> Co-authored-by: ProExpertProg <luka.govedic@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import enum
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from dataclasses import field
|
||||
from dataclasses import field, fields
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||
|
||||
@@ -269,6 +269,24 @@ class PassConfig:
|
||||
)
|
||||
self.fuse_rope_kvcache = False
|
||||
|
||||
def log_enabled_passes(self) -> None:
|
||||
"""
|
||||
Log the enabled custom fusion passes.
|
||||
This is called at the end of VLLMConfig post_init,
|
||||
after all defaults are finalized.
|
||||
TODO also log the compile ranges for which this is enabled.
|
||||
"""
|
||||
enabled_fusions = [
|
||||
f.name[len("fuse_") :]
|
||||
for f in fields(self)
|
||||
if getattr(self, f.name) and f.name.startswith("fuse_")
|
||||
]
|
||||
|
||||
if enabled_fusions:
|
||||
logger.info_once(
|
||||
"Enabled custom fusions: %s", ", ".join(enabled_fusions), scope="global"
|
||||
)
|
||||
|
||||
|
||||
class DynamicShapesType(str, enum.Enum):
|
||||
"""Types of dynamic shapes handling in torch.compile().
|
||||
@@ -341,7 +359,8 @@ class CompilationConfig:
|
||||
VLLMConfig's post_init does further initialization. If used outside of the
|
||||
VLLMConfig, some fields will be left in an improper state.
|
||||
|
||||
It has three parts:
|
||||
It contains PassConfig, which controls the custom fusion/transformation passes.
|
||||
The rest has three parts:
|
||||
|
||||
- Top-level Compilation control:
|
||||
- [`mode`][vllm.config.CompilationConfig.mode]
|
||||
|
||||
@@ -1272,6 +1272,9 @@ class VllmConfig:
|
||||
# Handle the KV connector configs
|
||||
self._post_init_kv_transfer_config()
|
||||
|
||||
# Log the custom passes that are enabled
|
||||
self.compilation_config.pass_config.log_enabled_passes()
|
||||
|
||||
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
|
||||
# remove the sizes that not multiple of tp_size when
|
||||
# enable sequence parallelism
|
||||
|
||||
Reference in New Issue
Block a user