[torch.compile] support all attention backends (#10558)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-22 14:04:42 -08:00
committed by GitHub
parent db100c5cde
commit eebad39f26
77 changed files with 876 additions and 648 deletions

View File

@@ -1,21 +1,38 @@
from contextlib import contextmanager
from typing import Any
from dataclasses import dataclass
from typing import Any, Dict, Optional
_forward_context: Any = None
from vllm.config import VllmConfig
def get_forward_context() -> Any:
@dataclass
class ForwardContext:
    """Context object made available to all layers during a forward pass.

    Bundles the static per-model context with the dynamic per-step context
    (e.g. attention metadata) so layers can look up either during execution.
    """
    # Populated from vllm_config.compilation_config.static_forward_context
    # when set_forward_context() is entered; maps layer names to per-layer
    # objects — presumably attention layer state, TODO confirm with callers.
    static_forward_context: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    # The `context` argument passed to set_forward_context(); per the
    # original docstring this "can be attention metadata, etc."
    dynamic_forward_context: Any
# Module-level holder for the currently active ForwardContext.
# None whenever execution is outside a set_forward_context() block.
_forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The ``ForwardContext`` installed by ``set_forward_context``.

    Raises:
        AssertionError: if called outside a ``set_forward_context`` block.
    """
    # Explicit raise instead of a bare `assert`: assertions are stripped
    # under `python -O`, which would silently return None to callers here.
    # AssertionError is kept so existing callers see the same exception type.
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context.")
    return _forward_context
@contextmanager
def set_forward_context(context: Any):
def set_forward_context(context: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context,
can be attention metadata, etc."""
global _forward_context
prev_context = _forward_context
_forward_context = context
_forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
try:
yield
finally: