[torch.compile] Hide KV cache behind torch.compile boundary (#11677)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -2,7 +2,7 @@ import time
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
 
@@ -10,6 +10,9 @@ import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+
 logger = init_logger(__name__)
 
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)
 
 @dataclass
 class ForwardContext:
-    static_forward_context: Dict[str, Any]
+    # copy from vllm_config.compilation_config.static_forward_context
+    attn_layers: Dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
-    dynamic_forward_context: Any
+    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    # TODO: remove after making all virtual_engines share the same kv cache
+    virtual_engine: int  # set dynamically for each forward pass
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:
 
 
 @contextmanager
-def set_forward_context(context: Any, vllm_config: VllmConfig):
+def set_forward_context(attn_metadata: Any,
+                        vllm_config: VllmConfig,
+                        virtual_engine: int = 0):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
     global forward_start_time
-    need_to_track_batchsize = track_batchsize and context is not None
+    need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        static_forward_context=vllm_config.compilation_config.
-        static_forward_context,
-        dynamic_forward_context=context)
+        attn_layers=vllm_config.compilation_config.static_forward_context,
+        virtual_engine=virtual_engine,
+        attn_metadata=attn_metadata)
     try:
         yield
     finally:
         global batchsize_counter
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
-            if hasattr(context, "num_prefill_tokens"):
+            if hasattr(attn_metadata, "num_prefill_tokens"):
                 # for v0 attention backends
-                batchsize = context.num_prefill_tokens + \
-                    context.num_decode_tokens
+                batchsize = attn_metadata.num_prefill_tokens + \
+                    attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = context.num_input_tokens
+                batchsize = attn_metadata.num_input_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
            # scheduling of the next batch
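For illustration, here is a minimal usage sketch of the new API, assuming the module above is vllm.forward_context; MyAttention, run_model, and the kv_cache attribute are hypothetical names, not part of this commit. A runner wraps each forward pass in set_forward_context, and a layer reaches its own entry in attn_layers through get_forward_context, so the attention metadata and KV cache do not have to be inputs of the torch.compile'd graph.

import torch

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context


class MyAttention(torch.nn.Module):
    # Hypothetical layer: registers itself in the static forward context
    # at construction time, keyed by its prefix, so the compiled forward
    # pass can look it up by name instead of taking it as a graph input.
    def __init__(self, prefix: str, vllm_config: VllmConfig):
        super().__init__()
        self.prefix = prefix
        self.kv_cache = torch.zeros(0)  # (re)bound by the runner, outside the graph
        vllm_config.compilation_config.static_forward_context[prefix] = self

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        ctx = get_forward_context()
        attn_metadata = ctx.attn_metadata          # set per forward pass
        this_layer = ctx.attn_layers[self.prefix]  # == self, via the static map
        # A real kernel would read/write this_layer.kv_cache here, driven by
        # attn_metadata; this sketch just passes q through.
        return q


def run_model(model: torch.nn.Module, x: torch.Tensor, attn_metadata,
              vllm_config: VllmConfig) -> torch.Tensor:
    # The runner wraps every forward pass; layers pull the metadata and
    # KV cache from the context rather than from compiled-graph inputs.
    with set_forward_context(attn_metadata, vllm_config, virtual_engine=0):
        return model(x)

Per the commit title, the point of this indirection is that the KV cache stays hidden behind the torch.compile boundary: the traced graph never sees the cache tensor as an input, so the runner can manage it independently of compilation.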