[torch.compile] Hide KV cache behind torch.compile boundary (#11677)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>

Author:    Chen Zhang <zhangch99@outlook.com>
Date:      2025-01-10 13:14:42 +08:00
Committer: GitHub
Parent:    3de2b1eafb
Commit:    cf5f000d21

18 changed files with 198 additions and 44 deletions

vllm/forward_context.py

@@ -2,7 +2,7 @@ import time
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
@@ -10,6 +10,9 @@ import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+
 logger = init_logger(__name__)
 
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
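
The import is guarded by TYPE_CHECKING so that type checkers can resolve the string annotation on ForwardContext.attn_metadata without pulling the attention backends into the runtime import graph, which would risk a circular import. The idiom in isolation, with a hypothetical module standing in for the backend:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (mypy, pyright);
        # skipped at runtime, so no import cycle is created.
        from heavy_module import HeavyType  # hypothetical module

    def process(x: "HeavyType") -> None:  # string annotation, resolved lazily
        ...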
@@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)
 
 @dataclass
 class ForwardContext:
-    static_forward_context: Dict[str, Any]
     # copy from vllm_config.compilation_config.static_forward_context
+    attn_layers: Dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
-    dynamic_forward_context: Any
+    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    # TODO: remove after making all virtual_engines share the same kv cache
+    virtual_engine: int  # set dynamically for each forward pass
 
 
 _forward_context: Optional[ForwardContext] = None
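
In this scheme, attn_layers is filled once at model-construction time (each attention layer registers itself under its name in vllm_config.compilation_config.static_forward_context), while attn_metadata and virtual_engine change on every forward pass. A layer inside the torch.compile region can then recover everything it needs by name instead of taking the KV cache as a forward argument. A minimal sketch of the lookup side, assuming a hypothetical helper and a hypothetical per-engine kv_cache attribute on the layer:

    from vllm.forward_context import get_forward_context

    def lookup_layer_state(layer_name: str):
        # Hypothetical helper: resolve a layer's per-pass state by name.
        ctx = get_forward_context()
        layer = ctx.attn_layers[layer_name]            # registered at init
        kv_cache = layer.kv_cache[ctx.virtual_engine]  # hypothetical attribute
        return layer, kv_cache, ctx.attn_metadata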
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:
 
 @contextmanager
-def set_forward_context(context: Any, vllm_config: VllmConfig):
+def set_forward_context(attn_metadata: Any,
+                        vllm_config: VllmConfig,
+                        virtual_engine: int = 0):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
     global forward_start_time
-    need_to_track_batchsize = track_batchsize and context is not None
+    need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        static_forward_context=vllm_config.compilation_config.
-        static_forward_context,
-        dynamic_forward_context=context)
+        attn_layers=vllm_config.compilation_config.static_forward_context,
+        virtual_engine=virtual_engine,
+        attn_metadata=attn_metadata)
     try:
         yield
     finally:
         global batchsize_counter
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
-            if hasattr(context, "num_prefill_tokens"):
+            if hasattr(attn_metadata, "num_prefill_tokens"):
                 # for v0 attention backends
-                batchsize = context.num_prefill_tokens + \
-                    context.num_decode_tokens
+                batchsize = attn_metadata.num_prefill_tokens + \
+                    attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = context.num_input_tokens
+                batchsize = attn_metadata.num_input_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
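
On the caller side the change is mechanical: model runners now pass the attention metadata under its new name, plus the virtual engine index where pipeline parallelism keeps separate KV caches. A usage sketch of the new signature, with model, input_ids, attn_metadata, and vllm_config as hypothetical stand-ins for a runner's real state:

    from vllm.forward_context import set_forward_context

    with set_forward_context(attn_metadata, vllm_config, virtual_engine=0):
        # Layers read attn_metadata/virtual_engine via get_forward_context()
        # instead of receiving kv_caches through the model's signature.
        hidden_states = model(input_ids)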