[torch.compile] Hide KV cache behind torch.compile boundary (#11677)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-01-10 13:14:42 +08:00
committed by GitHub
parent 3de2b1eafb
commit cf5f000d21
18 changed files with 198 additions and 44 deletions

View File

@@ -16,7 +16,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
LayerBlockType, bind_kv_cache, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
@@ -860,3 +861,6 @@ class GPUModelRunner:
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[self.kv_caches])