[v1] Refactor KVCacheConfig (#14079)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -1510,34 +1510,46 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
if len(kv_cache_config.groups) > 1:
|
||||
if len(kv_cache_config.kv_cache_groups) > 1:
|
||||
raise NotImplementedError(
|
||||
"Hybrid models with more than one KV cache type are not "
|
||||
"supported yet.")
|
||||
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
|
||||
tensor_config = kv_cache_config.tensors[layer_name]
|
||||
assert tensor_config.size % layer_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_config.size // layer_spec.page_size_bytes
|
||||
if isinstance(layer_spec, FullAttentionSpec):
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
|
||||
layer_spec.head_size)
|
||||
dtype = layer_spec.dtype
|
||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
tensor_config = kv_cache_config.tensors[layer_name]
|
||||
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
|
||||
# `num_blocks` is the number of blocks the model runner can use.
|
||||
# `kv_cache_config.num_blocks` is the number of blocks that
|
||||
# KVCacheManager may allocate.
|
||||
# Since different GPUs may have different number of layers and
|
||||
# different memory capacities, `num_blocks` can be different on
|
||||
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
else:
|
||||
# TODO: add new branches when introducing more types of
|
||||
# KV cache specs.
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
@@ -1549,7 +1561,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
forward_ctx = self.vllm_config.compilation_config.static_forward_context
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
kv_cache_spec: KVCacheSpec = {}
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
for layer_name, attn_module in forward_ctx.items():
|
||||
if isinstance(attn_module, FusedMoE):
|
||||
continue
|
||||
|
||||
@@ -185,7 +185,7 @@ class Worker(WorkerBase):
|
||||
|
||||
return int(available_kv_cache_memory)
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
|
||||
@@ -309,7 +309,7 @@ class TPUModelRunner:
|
||||
assert self.model is not None
|
||||
return self.model
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
@@ -320,7 +320,7 @@ class TPUModelRunner:
|
||||
|
||||
forward_ctx = self.vllm_config.compilation_config.static_forward_context
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
kv_cache_spec: KVCacheSpec = {}
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
for layer_name, attn_module in forward_ctx.items():
|
||||
# TODO: Support other attention modules, e.g., sliding window,
|
||||
# cross-attention, MLA.
|
||||
@@ -837,31 +837,33 @@ class TPUModelRunner:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
if len(kv_cache_config.groups) > 1:
|
||||
if len(kv_cache_config.kv_cache_groups) > 1:
|
||||
raise NotImplementedError(
|
||||
"Hybrid models with more than one KV cache type are not "
|
||||
"supported yet.")
|
||||
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
|
||||
tensor_config = kv_cache_config.tensors[layer_name]
|
||||
assert tensor_config.size % layer_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_config.size // layer_spec.page_size_bytes
|
||||
if isinstance(layer_spec, FullAttentionSpec):
|
||||
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
|
||||
layer_spec.head_size)
|
||||
dtype = layer_spec.dtype
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
tensor_config = kv_cache_config.tensors[layer_name]
|
||||
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
|
||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
|
||||
tpu_k_cache = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||
tpu_k_cache = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||
|
||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
|
||||
@@ -189,7 +189,7 @@ class TPUWorker:
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
|
||||
@@ -51,7 +51,7 @@ class WorkerBase(WorkerBaseV0):
|
||||
self.device: Optional[torch.device] = None
|
||||
self.model_runner: Optional[nn.Module] = None
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""Get specifications for KV cache implementation."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user