[KVConnector][Core] Support cross-layer KV blocks (#27743)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Or Ozeri, 2025-11-20 20:09:59 +02:00 (committed by GitHub)
parent e5bfcb6a88
commit 647464719b
15 changed files with 453 additions and 90 deletions


@@ -12,8 +12,10 @@ from tqdm import tqdm
 from vllm import LLM, SamplingParams, TokensPrompt
 from vllm.config import KVEventsConfig, KVTransferConfig
 from vllm.distributed.kv_events import BlockStored, KVEventBatch
+from vllm.utils.system_utils import set_env_var
 
-CPU_BLOCK_SIZES = [16, 48]
+CPU_BLOCK_SIZES = [48]
+ATTN_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]
 
 
 class MockSubscriber:
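A note on the new import: `set_env_var` is used in the last hunk below to pin `VLLM_ATTENTION_BACKEND` while the engine is constructed, so each parametrized backend actually takes effect. As a minimal sketch of what such a helper typically does (the real `vllm.utils.system_utils.set_env_var` may differ in details):

```python
import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def set_env_var(key: str, value: str) -> Iterator[None]:
    """Temporarily set an environment variable, restoring its old state on exit."""
    old = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(key, None)  # variable was unset before; unset it again
        else:
            os.environ[key] = old
```

Scoping the override to a context manager keeps the environment clean between parametrized test cases.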
@@ -63,8 +65,88 @@ class MockSubscriber:
         self.sub.close()
 
 
+def _latency_test(llm: LLM, subscriber: MockSubscriber):
+    sampling_params = SamplingParams(max_tokens=1)
+
+    num_times_cpu_better_than_cold = 0
+    num_tests = 10
+    total_cold_time = 0.0
+    total_gpu_hit_time = 0.0
+    total_cpu_hit_time = 0.0
+    prompt_token_ids = [0] * 10001
+    for i in tqdm(range(num_tests), desc="Running tests"):
+        prompt_token_ids[0] = i
+        prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
+
+        # run generation - this should trigger saving KV cache
+        start_time = time.time()
+        llm.generate(prompts, sampling_params, use_tqdm=False)
+        cold_time = time.time() - start_time
+        total_cold_time += cold_time
+
+        # run generation again - should hit the GPU prefix cache
+        start_time = time.time()
+        llm.generate(prompts, sampling_params, use_tqdm=False)
+        gpu_hit_time = time.time() - start_time
+        total_gpu_hit_time += gpu_hit_time
+
+        # reset prefix cache to avoid GPU hit.
+        llm.reset_prefix_cache()
+        assert subscriber.get_new_cpu_stored_events()
+
+        # run generation again - this should trigger loading from CPU
+        start_time = time.time()
+        llm.generate(prompts, sampling_params, use_tqdm=False)
+        cpu_hit_time = time.time() - start_time
+        total_cpu_hit_time += cpu_hit_time
+
+        if cpu_hit_time < cold_time:
+            num_times_cpu_better_than_cold += 1
+
+    print("Average times:")
+    print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
+    print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
+    print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")
+
+    assert num_times_cpu_better_than_cold >= 0.8 * num_tests
+
+
+def _accuracy_test(llm: LLM, subscriber: MockSubscriber):
+    sampling_params = SamplingParams(max_tokens=1)
+
+    cpu_block_size = (
+        llm.llm_engine.vllm_config.kv_transfer_config.kv_connector_extra_config[
+            "block_size"
+        ]
+    )
+
+    subscriber.get_new_cpu_stored_events()
+
+    # prepend prompt to be cpu block aligned
+    prompt = "Let's count to 10. One, two, three, four,"
+    while (
+        len(llm.generate(prompt, use_tqdm=False)[0].prompt_token_ids) % cpu_block_size
+        != 0
+    ):
+        prompt = ". " + prompt
+
+    assert subscriber.get_new_cpu_stored_events()
+
+    test_count = 100
+    success_count = 0
+    for i in range(test_count):
+        if (
+            llm.generate(prompt, sampling_params, use_tqdm=False)[0].outputs[0].text
+            == " five"
+        ):
+            success_count += 1
+
+    assert success_count >= 0.5 * test_count
+
+
 @pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
-def test_cpu_offloading(cpu_block_size: int) -> None:
+@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
+def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
     """
     Tests OffloadingConnector with CPUOffloadingSpec.
     """
@@ -92,61 +174,20 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
         topic="test",
     )
 
-    llm = LLM(
-        model="meta-llama/Llama-3.2-1B-Instruct",
-        gpu_memory_utilization=0.5,
-        kv_events_config=kv_events_config,
-        kv_transfer_config=kv_transfer_config,
-    )
-
-    sampling_params = SamplingParams(temperature=0, max_tokens=1)
+    with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
+        llm = LLM(
+            model="meta-llama/Llama-3.2-1B-Instruct",
+            gpu_memory_utilization=0.5,
+            kv_events_config=kv_events_config,
+            kv_transfer_config=kv_transfer_config,
+        )
 
     events_endpoint = events_endpoint.replace("*", "127.0.0.1")
     subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
 
     try:
-        num_times_cpu_better_than_cold = 0
-        num_tests = 10
-        total_cold_time = 0.0
-        total_gpu_hit_time = 0.0
-        total_cpu_hit_time = 0.0
-        prompt_token_ids = [0] * 10001
-        for i in tqdm(range(num_tests), desc="Running tests"):
-            prompt_token_ids[0] = i
-            prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
-
-            # run generation - this should trigger saving KV cache
-            start_time = time.time()
-            llm.generate(prompts, sampling_params, use_tqdm=False)
-            cold_time = time.time() - start_time
-            total_cold_time += cold_time
-
-            # run generation again - should hit the GPU prefix cache
-            start_time = time.time()
-            llm.generate(prompts, sampling_params, use_tqdm=False)
-            gpu_hit_time = time.time() - start_time
-            total_gpu_hit_time += gpu_hit_time
-
-            # reset prefix cache to avoid GPU hit.
-            llm.reset_prefix_cache()
-            assert subscriber.get_new_cpu_stored_events()
-
-            # run generation again - this should trigger loading from CPU
-            start_time = time.time()
-            llm.generate(prompts, sampling_params, use_tqdm=False)
-            cpu_hit_time = time.time() - start_time
-            total_cpu_hit_time += cpu_hit_time
-
-            if cpu_hit_time < cold_time:
-                num_times_cpu_better_than_cold += 1
-
-        print("Average times:")
-        print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
-        print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
-        print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")
-
-        assert num_times_cpu_better_than_cold >= 0.8 * num_tests
+        _latency_test(llm, subscriber)
+        _accuracy_test(llm, subscriber)
     finally:
         subscriber.close()
         del llm
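For readers who want to reproduce the setup: the test reads the CPU block size back out of `kv_transfer_config.kv_connector_extra_config["block_size"]`, which implies configs along these lines. Only the pieces visible in the diff are certain (the imported config classes, the `"test"` topic, the `OffloadingConnector` named in the docstring, and the `block_size` key); the publisher, endpoint, role, and any other extra-config keys are assumptions here, not taken from the diff:

```python
# Hypothetical sketch of the configs this test builds; see the caveats above.
from vllm.config import KVEventsConfig, KVTransferConfig

kv_events_config = KVEventsConfig(
    enable_kv_cache_events=True,  # assumption: events must be on for BlockStored
    publisher="zmq",              # assumption: MockSubscriber connects over ZMQ
    endpoint="tcp://*:5557",      # assumption: placeholder endpoint
    topic="test",                 # from the diff
)

kv_transfer_config = KVTransferConfig(
    kv_connector="OffloadingConnector",            # from the docstring
    kv_role="kv_both",                             # assumption
    kv_connector_extra_config={"block_size": 48},  # key read back by _accuracy_test
)
```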