[torch.compile] support all attention backends (#10558)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-22 14:04:42 -08:00
Committed by: GitHub
Parent: db100c5cde
Commit: eebad39f26
77 changed files with 876 additions and 648 deletions


@@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
-        "vllm.unified_flash_attention",
-        "vllm.unified_flash_infer",
+        "vllm.unified_attention",
         "vllm.unified_v1_flash_attention",
     ])
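
The splitting ops are the custom ops at which the piecewise compilation splits the graph; this change collapses the two backend-specific ops into a single "vllm.unified_attention" op, so one compiled graph can serve any attention backend. Below is a minimal sketch, not vLLM's code, of why one opaque custom op achieves that: the compiler never traces into the op body, so the backend dispatch inside it can change at runtime. The "demo::" namespace and the SDPA fallback are illustrative assumptions.

import torch
from torch.library import custom_op


@custom_op("demo::unified_attention", mutates_args=())
def unified_attention(q: torch.Tensor, k: torch.Tensor,
                      v: torch.Tensor) -> torch.Tensor:
    # Backend choice (FlashAttention, FlashInfer, ...) happens here,
    # invisible to torch.compile, which only sees the opaque op node.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


@unified_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation only; lets the compiler trace around the op.
    return torch.empty_like(q)

Because the graph splits at this op, swapping attention backends changes only the op body at runtime, never the compiled artifact.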
@@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel):
     enabled_custom_ops: Counter[str] = PrivateAttr
     disabled_custom_ops: Counter[str] = PrivateAttr
+    # Per-model forward context
+    # Mainly used to store attention cls
+    # Map from layer name to the attention cls
+    static_forward_context: Dict[str, Any] = PrivateAttr
+
     @classmethod
     def from_cli(cls, cli_value: str) -> "CompilationConfig":
         """Parse the CLI value for the compilation config."""
@@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel):
         self.enabled_custom_ops = Counter()
         self.disabled_custom_ops = Counter()
+        self.static_forward_context = {}

     def init_backend(self) -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
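
The private attributes declared with PrivateAttr above are given concrete empty containers when the config object is constructed. A trimmed, self-contained sketch of that pattern, assuming standard pydantic v2 conventions (where private attributes are underscore-prefixed; the diff keeps the names without underscores):

from collections import Counter
from typing import Any, Dict

from pydantic import BaseModel, PrivateAttr


class CompilationConfig(BaseModel):
    backend: str = ""

    _enabled_custom_ops: Counter = PrivateAttr()
    _disabled_custom_ops: Counter = PrivateAttr()
    _static_forward_context: Dict[str, Any] = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        # Give every instance fresh, empty containers.
        self._enabled_custom_ops = Counter()
        self._disabled_custom_ops = Counter()
        self._static_forward_context = {}


cfg = CompilationConfig()
assert cfg._static_forward_context == {}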