[torch.compile] support all attention backends (#10558)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1,21 +1,38 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
_forward_context: Any = None
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
def get_forward_context() -> Any:
|
||||
@dataclass
|
||||
class ForwardContext:
|
||||
static_forward_context: Dict[str, Any]
|
||||
# TODO: extend to support per-layer dynamic forward context
|
||||
dynamic_forward_context: Any
|
||||
|
||||
|
||||
_forward_context: Optional[ForwardContext] = None
|
||||
|
||||
|
||||
def get_forward_context() -> ForwardContext:
|
||||
"""Get the current forward context."""
|
||||
assert _forward_context is not None, (
|
||||
"Forward context is not set. "
|
||||
"Please use `set_forward_context` to set the forward context.")
|
||||
return _forward_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_forward_context(context: Any):
|
||||
def set_forward_context(context: Any, vllm_config: VllmConfig):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc."""
|
||||
global _forward_context
|
||||
prev_context = _forward_context
|
||||
_forward_context = context
|
||||
_forward_context = ForwardContext(
|
||||
static_forward_context=vllm_config.compilation_config.
|
||||
static_forward_context,
|
||||
dynamic_forward_context=context)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user