Support multiple attention groups for KV sharing (#22672)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -225,26 +225,34 @@ def initialize_kv_cache_for_kv_sharing(
|
||||
Note that layers in shared_kv_cache_layers.keys() are not
|
||||
originally included as it only contains layers which have their own
|
||||
KV cache allocation.
|
||||
attn_groups: Optional list of attention groups. Layers in the same KV
|
||||
cache group may be placed in different attention groups if they
|
||||
have different attention backends. Currently only provided by
|
||||
the GPU model runner.
|
||||
"""
|
||||
# Record index of KV cache group for each layer that allocates a KV cache.
|
||||
layer_to_kv_cache_group_idx: dict[str, int] = {}
|
||||
for i, kv_cache_group in enumerate(kv_cache_groups):
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
layer_to_kv_cache_group_idx[layer_name] = i
|
||||
# mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
|
||||
layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
|
||||
if attn_groups:
|
||||
for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
|
||||
for attn_group_idx, attn_group in enumerate(kv_attn_groups):
|
||||
for layer_name in attn_group.layer_names:
|
||||
layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
|
||||
attn_group_idx)
|
||||
else:
|
||||
for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
# attn group idx defaults to 0 if not provided
|
||||
layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)
|
||||
|
||||
for layer_name, target_layer_name in shared_kv_cache_layers.items():
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
|
||||
kv_cache_groups[group_idx].layer_names.append(layer_name)
|
||||
kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0]
|
||||
kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)
|
||||
|
||||
if attn_groups is not None:
|
||||
assert len(attn_groups[group_idx]) == 1, (
|
||||
"Only one attention group per KV cache group is supported "
|
||||
"for KV-cache sharing for now.")
|
||||
# TODO(lucas): I think in the future the layers that re-use a
|
||||
# KV cache will be in a different attention group so we can
|
||||
# remove this code from here.
|
||||
attn_groups[group_idx][0].layer_names.append(layer_name)
|
||||
if attn_groups:
|
||||
attn_group_idx = layer_to_attn_group_idx[target_layer_name][1]
|
||||
attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
|
||||
layer_name)
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
|
||||
Reference in New Issue
Block a user