40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional
|
|
|
|
from vllm.config import VllmConfig
|
|
|
|
|
|
@dataclass
class ForwardContext:
    """Pair of static and dynamic state exposed through the forward context.

    Instances are built by `set_forward_context` and read back through
    `get_forward_context`.
    """

    # Taken from `vllm_config.compilation_config.static_forward_context`
    # when the context is installed.
    static_forward_context: Dict[str, Any]

    # The caller-supplied `context` argument of `set_forward_context`
    # (per its docstring: attention metadata, etc.).
    # TODO: extend to support per-layer dynamic forward context
    dynamic_forward_context: Any
|
|
|
|
|
|
# Module-level slot for the currently active ForwardContext. It starts as
# None; `set_forward_context` installs a value for the duration of its
# `with` block and restores the previous value on exit.
_forward_context: Optional[ForwardContext] = None
|
|
|
|
|
|
def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The ``ForwardContext`` currently stored in the module-level
        ``_forward_context`` variable.

    Raises:
        AssertionError: If no forward context is set, i.e. the call is
            made outside of a ``set_forward_context`` block.
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would let this function silently
    # return None. AssertionError is kept so existing callers see the same
    # exception type and message.
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context.")
    return _forward_context
|
|
|
|
|
|
@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
    """Install a forward context for the duration of a ``with`` block.

    The stored ``ForwardContext`` combines
    ``vllm_config.compilation_config.static_forward_context`` with the
    given ``context`` (which can be attention metadata, etc.). Whatever
    context was active before entry is put back on exit, including when
    the body raises.
    """
    global _forward_context

    # Remember the outer context so it can be restored on exit.
    saved = _forward_context
    static_ctx = vllm_config.compilation_config.static_forward_context
    _forward_context = ForwardContext(
        static_forward_context=static_ctx,
        dynamic_forward_context=context,
    )
    try:
        yield
    finally:
        _forward_context = saved
|