[V1] Support cross-layer KV sharing (#18212)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -7,8 +7,11 @@ import pytest
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VllmConfig)
|
||||
SchedulerConfig, VllmConfig, set_current_vllm_config)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
|
||||
get_kv_cache_config)
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
@@ -19,6 +22,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
NUM_BLOCKS = 10
|
||||
DEVICE = "cuda"
|
||||
|
||||
|
||||
def initialize_kv_cache(runner: GPUModelRunner):
|
||||
@@ -55,8 +59,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
runner.initialize_attn_backend(kv_cache_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_runner():
|
||||
def get_vllm_config():
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=10,
|
||||
max_num_batched_tokens=512,
|
||||
@@ -84,13 +87,18 @@ def model_runner():
|
||||
scheduler_config=scheduler_config,
|
||||
parallel_config=parallel_config,
|
||||
)
|
||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
return vllm_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_runner():
|
||||
vllm_config = get_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
vllm_config.compilation_config.static_forward_context[
|
||||
"layer.0"] = Attention(num_heads, head_size, 0.1)
|
||||
|
||||
device = "cuda"
|
||||
runner = GPUModelRunner(vllm_config, device)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
initialize_kv_cache(runner)
|
||||
return runner
|
||||
|
||||
@@ -385,3 +393,225 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
|
||||
model_runner_2.load_model() # Load real weights inplace
|
||||
assert str(model_runner.get_model().state_dict()) == str(
|
||||
model_runner_2.get_model().state_dict())
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
error_msg = f"{layer_1} must come before the current layer"
|
||||
with pytest.raises(ValueError, match=error_msg):
|
||||
fwd_context = {
|
||||
# initialization below will fail because target layer is invalid;
|
||||
# the target layer needs to come before layer 1
|
||||
layer_0:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
kv_sharing_target_layer_name=layer_1,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
)
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
invalid_layer = "model.layers.0.cross_attn.attn"
|
||||
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
|
||||
with pytest.raises(ValueError, match=error_msg):
|
||||
fwd_context = {
|
||||
layer_0:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
# invalid layer: cross_attn.atn doesn't exist!
|
||||
kv_sharing_target_layer_name=invalid_layer,
|
||||
)
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
error_msg = f"{layer_1} cannot be the same as the current layer"
|
||||
with pytest.raises(ValueError, match=error_msg):
|
||||
fwd_context = {
|
||||
# initialization below will fail because target layer is invalid;
|
||||
# the target layer needs to come before layer 1
|
||||
layer_0:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
kv_sharing_target_layer_name=layer_1,
|
||||
)
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_without_kv_sharing():
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
vllm_config = get_vllm_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
fwd_context = {
|
||||
layer_0:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
)
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
# Set high context length to test max context length estimation
|
||||
vllm_config.model_config.max_model_len = 3_000_000
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
assert len(kv_cache_spec) == 2
|
||||
assert len(runner.shared_kv_cache_layers) == 0
|
||||
|
||||
available_memory = 20 * GiB_bytes
|
||||
# page size for layer 0's kv_cache_spec is 32KB
|
||||
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.tensors) == 2
|
||||
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
|
||||
assert kv_cache_config.tensors[layer_1].size == available_memory // 2
|
||||
|
||||
max_context_len =\
|
||||
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
# max context len with KV sharing should be 2x as large as without
|
||||
assert max_context_len == 1310720
|
||||
|
||||
# important: override tensor size to prevent large mem alloc during test
|
||||
# this will only allocate 2 block worth of memory (2 * 32kb)
|
||||
kv_cache_config.num_blocks = 1
|
||||
for layer in kv_cache_config.tensors:
|
||||
kv_cache_config.tensors[layer].size =\
|
||||
kv_cache_spec[layer].page_size_bytes
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
||||
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
||||
# check layer 1 kv cache does NOT share memory with layer 0
|
||||
assert id(layer_1_kv) != id(layer_0_kv)
|
||||
|
||||
# check layer 1 added to kv cache group's layer names
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_valid():
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
vllm_config = get_vllm_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
fwd_context = {
|
||||
layer_0:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
|
||||
)
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
# Set high context length to test max context length estimation
|
||||
vllm_config.model_config.max_model_len = 3_000_000
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
assert len(kv_cache_spec) == 1
|
||||
assert layer_0 in kv_cache_spec
|
||||
assert runner.shared_kv_cache_layers[layer_1] == layer_0
|
||||
|
||||
available_memory = 20 * GiB_bytes
|
||||
# page size for layer 0's kv_cache_spec is 32KB
|
||||
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
|
||||
# which is twice as many as without KV sharing
|
||||
num_expected_blocks = 655360 # 20GB / 32KB
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.tensors) == 1
|
||||
# Each layer now has twice the available memory for KV cache
|
||||
# compared to no KV sharing
|
||||
assert kv_cache_config.tensors[layer_0].size == available_memory
|
||||
|
||||
max_context_len =\
|
||||
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
# max context len with KV sharing should be 2x as large as without
|
||||
assert max_context_len == 2 * 1310720
|
||||
|
||||
# important: override tensor size to prevent large mem alloc during test
|
||||
# this will only allocate 1 block worth of memory (32kb)
|
||||
kv_cache_config.num_blocks = 1
|
||||
kv_cache_config.tensors[layer_0].size =\
|
||||
kv_cache_spec[layer_0].page_size_bytes
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
||||
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
||||
# check layer 1 kv cache shares memory with layer 0
|
||||
assert id(layer_1_kv) == id(layer_0_kv)
|
||||
|
||||
# check layer 1 added to kv cache group's layer names
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
|
||||
Reference in New Issue
Block a user