[torch.compile] support all attention backends (#10558)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
-        "vllm.unified_flash_attention",
-        "vllm.unified_flash_infer",
+        "vllm.unified_attention",
         "vllm.unified_v1_flash_attention",
     ])
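The splitting ops are the custom ops at which the piecewise compilation splits the FX graph, so attention always runs eagerly, outside the compiled (and CUDA-graph-captured) pieces; collapsing the backend-specific ops into a single vllm.unified_attention op is what lets one splitting point cover all attention backends. A minimal usage sketch, assuming CompilationConfig behaves as an ordinary pydantic model (which its BaseModel base suggests) and that it is importable from vllm.config:

    from vllm.config import CompilationConfig

    # Override the default so the graph is split only at the unified op
    # (import path and direct construction are assumptions, not shown in
    # this diff).
    config = CompilationConfig(splitting_ops=["vllm.unified_attention"])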
@@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel):
     enabled_custom_ops: Counter[str] = PrivateAttr
     disabled_custom_ops: Counter[str] = PrivateAttr

+    # Per-model forward context
+    # Mainly used to store attention cls
+    # Map from layer name to the attention cls
+    static_forward_context: Dict[str, Any] = PrivateAttr
+
     @classmethod
     def from_cli(cls, cli_value: str) -> "CompilationConfig":
         """Parse the CLI value for the compilation config."""
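A minimal, self-contained sketch (not the actual vLLM code) of the "map from layer name to the attention cls" pattern added above: each attention layer registers itself in the per-model context, and a graph-level custom op looks the layer up by name at runtime instead of capturing the module inside the traced graph. All names below are illustrative assumptions:

    from typing import Any, Dict

    static_forward_context: Dict[str, Any] = {}

    class Attention:
        def __init__(self, layer_name: str) -> None:
            # Register under a fully-qualified name (hypothetical scheme).
            static_forward_context[layer_name] = self

        def forward(self, q, k, v):
            # A real backend (FlashAttention, FlashInfer, ...) would run here.
            return q

    def unified_attention(q, k, v, layer_name: str):
        # The compiled graph only sees this opaque call, so any attention
        # backend can be swapped in without recompiling the graph.
        return static_forward_context[layer_name].forward(q, k, v)

    layer = Attention("model.layers.0.self_attn.attn")
    out = unified_attention("q", "k", "v", "model.layers.0.self_attn.attn")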
@@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel):

         self.enabled_custom_ops = Counter()
         self.disabled_custom_ops = Counter()
+        self.static_forward_context = {}

     def init_backend(self) -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
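A hedged example of the CLI entry point from the second hunk; the exact JSON schema accepted by from_cli is an assumption based on the field names visible in this diff, and the field defaults are initialized by the setup code in the last hunk:

    # Hypothetical invocation; "level" and "splitting_ops" are fields
    # referenced elsewhere in this diff, but the JSON format is assumed.
    config = CompilationConfig.from_cli(
        '{"level": 3, "splitting_ops": ["vllm.unified_attention"]}')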