Compare commits
7 Commits
v0.18.2rc0
...
v0.19.0rc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dbbafd3f3 | ||
|
|
0ee3b7fc3d | ||
|
|
268bed9cf3 | ||
|
|
bcc0fdd0f3 | ||
|
|
69b8bd4b33 | ||
|
|
12449f9492 | ||
|
|
b92312dfd7 |
@@ -1,9 +1,10 @@
|
||||
#!/bin/bash
|
||||
set -euox pipefail
|
||||
export VLLM_CPU_CI_ENV=0
|
||||
export VLLM_CPU_KVCACHE_SPACE=1 # avoid OOM
|
||||
|
||||
echo "--- PP+TP"
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 --max-model-len=4096 &
|
||||
server_pid=$!
|
||||
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
|
||||
vllm bench serve \
|
||||
@@ -23,7 +24,7 @@ if [ "$failed_req" -ne 0 ]; then
|
||||
fi
|
||||
|
||||
echo "--- DP+TP"
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
|
||||
server_pid=$!
|
||||
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
|
||||
vllm bench serve \
|
||||
|
||||
@@ -7,3 +7,4 @@ server_args: >-
|
||||
--max-model-len 4096
|
||||
--data-parallel-size 2
|
||||
--enable-expert-parallel
|
||||
--max-num-seqs 512
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -11,6 +9,8 @@ from vllm.model_executor.models.bert import (
|
||||
BertMLMHead,
|
||||
SPLADESparsePooler,
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Functional test: SPLADE formula correctness (no HF download needed)
|
||||
@@ -38,8 +38,12 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
],
|
||||
dtype=torch.long,
|
||||
)
|
||||
meta = types.SimpleNamespace(
|
||||
prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids
|
||||
meta = PoolingMetadata(
|
||||
prompt_lens=prompt_lens_tenser,
|
||||
prompt_token_ids=token_ids,
|
||||
prompt_token_ids_cpu=token_ids,
|
||||
pooling_params=[PoolingParams(task="embed")] * B,
|
||||
pooling_states=[PoolingStates() for _ in range(B)],
|
||||
)
|
||||
|
||||
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
|
||||
|
||||
@@ -15,6 +15,10 @@ from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
|
||||
COLBERT_DIM = 128
|
||||
DTYPE = "half"
|
||||
# Fixme:
|
||||
# Update colmodernvbert code to support the latest HF version
|
||||
# and remove revision set.
|
||||
REVISION = "4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -26,6 +30,7 @@ def test_colmodernvbert_text_token_embed(vllm_runner):
|
||||
"""Text query produces per-token embeddings with shape (seq_len, 128)."""
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -49,6 +54,7 @@ def test_colmodernvbert_text_relevance_ordering(vllm_runner):
|
||||
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -66,6 +72,7 @@ def test_colmodernvbert_text_late_interaction(vllm_runner):
|
||||
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -92,6 +99,7 @@ def test_colmodernvbert_image_token_embed(vllm_runner, image_assets):
|
||||
"""Image input produces per-token embeddings including vision tokens."""
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
|
||||
@@ -636,6 +636,7 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
|
||||
# [Multimodal]
|
||||
"ColModernVBertForRetrieval": _HfExamplesInfo(
|
||||
"ModernVBERT/colmodernvbert-merged",
|
||||
revision="4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee",
|
||||
),
|
||||
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
|
||||
"ColQwen3": _HfExamplesInfo(
|
||||
|
||||
@@ -42,6 +42,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
|
||||
from vllm.v1.attention.backends.utils import split_prefill_chunks
|
||||
from vllm.v1.attention.ops import flashmla
|
||||
|
||||
@@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
|
||||
assert out == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
|
||||
[
|
||||
# Logits constraint triggers split (M*N exceeds budget)
|
||||
# req0: M=10, N=100 -> 1000 elems (4000 bytes) - fits in 5000
|
||||
# req1: adding M=10, N=100 -> new_M=20, new_N=200 -> 4000 elems > 1250
|
||||
(
|
||||
torch.tensor([100, 100, 100]),
|
||||
torch.tensor([10, 10, 10]),
|
||||
1000, # workspace allows all
|
||||
5000, # 1250 float32 elems -> forces split
|
||||
[
|
||||
(slice(0, 1), slice(0, 10)),
|
||||
(slice(1, 2), slice(0, 10)),
|
||||
(slice(2, 3), slice(0, 10)),
|
||||
],
|
||||
),
|
||||
# Both constraints satisfied - all fit in one chunk
|
||||
(
|
||||
torch.tensor([10, 10, 10]),
|
||||
torch.tensor([5, 5, 5]),
|
||||
100,
|
||||
10000, # 2500 elems, M*N = 15*30 = 450 < 2500
|
||||
[(slice(0, 3), slice(0, 15))],
|
||||
),
|
||||
# Workspace constraint triggers first
|
||||
(
|
||||
torch.tensor([50, 50, 50]),
|
||||
torch.tensor([1, 1, 1]),
|
||||
50, # workspace only fits one at a time
|
||||
1000000, # logits budget is huge
|
||||
[
|
||||
(slice(0, 1), slice(0, 1)),
|
||||
(slice(1, 2), slice(0, 1)),
|
||||
(slice(2, 3), slice(0, 1)),
|
||||
],
|
||||
),
|
||||
# Greedy filling: first two fit, third doesn't
|
||||
# req0: M=5, N=10 -> 50 elems
|
||||
# req0+1: M=10, N=20 -> 200 elems <= 250
|
||||
# req0+1+2: M=15, N=30 -> 450 elems > 250
|
||||
(
|
||||
torch.tensor([10, 10, 10]),
|
||||
torch.tensor([5, 5, 5]),
|
||||
100,
|
||||
1000, # 250 elems
|
||||
[(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_indexer_prefill_chunks(
|
||||
seq_lens, query_lens, workspace_size, max_logits_bytes, expected
|
||||
):
|
||||
out = split_indexer_prefill_chunks(
|
||||
seq_lens,
|
||||
query_lens,
|
||||
workspace_size,
|
||||
max_logits_bytes,
|
||||
)
|
||||
assert out == expected
|
||||
|
||||
|
||||
def test_split_indexer_prefill_chunks_single_request_overflow():
|
||||
"""Test that single request exceeding budget is sub-chunked on query dim."""
|
||||
seq_lens = torch.tensor([1000, 50])
|
||||
query_lens = torch.tensor([100, 5])
|
||||
|
||||
out = split_indexer_prefill_chunks(seq_lens, query_lens, 2000, 1000)
|
||||
# max_logits_elems = 250, N=1000 -> max_q = 1 -> 100 query sub-chunks
|
||||
expected = [(slice(0, 1), slice(i, i + 1)) for i in range(100)]
|
||||
# req1: M=5, N=50 -> 250 elems fits budget
|
||||
expected.append((slice(1, 2), slice(0, 5)))
|
||||
assert out == expected
|
||||
|
||||
|
||||
def test_triton_convert_returns_valid_counts():
|
||||
"""Test that return_valid_counts correctly counts non-negative indices."""
|
||||
device = torch.device("cuda")
|
||||
|
||||
0
tests/v1/simple_kv_offload/__init__.py
Normal file
0
tests/v1/simple_kv_offload/__init__.py
Normal file
193
tests/v1/simple_kv_offload/test_integration.py
Normal file
193
tests/v1/simple_kv_offload/test_integration.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Integration tests for SimpleCPUOffloadConnector with real models."""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams, TokensPrompt
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("Requires CUDA", allow_module_level=True)
|
||||
|
||||
# Small models for default CI / local runs (accuracy only).
|
||||
SMALL_MODELS = [
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"google/gemma-3-1b-it",
|
||||
]
|
||||
|
||||
# Large models for optional perf runs only (slow to load and execute).
|
||||
PERF_MODELS = [
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
"openai/gpt-oss-20b",
|
||||
]
|
||||
|
||||
|
||||
def _make_llm(model: str, lazy: bool, cpu_bytes_to_use: int) -> LLM:
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="SimpleCPUOffloadConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"cpu_bytes_to_use": cpu_bytes_to_use,
|
||||
"lazy_offload": lazy,
|
||||
},
|
||||
)
|
||||
return LLM(
|
||||
model=model,
|
||||
kv_cache_memory_bytes=40 << 30, # 40 GiB
|
||||
disable_hybrid_kv_cache_manager=False,
|
||||
enable_prefix_caching=True,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
)
|
||||
|
||||
|
||||
def _flush_gpu_cache(llm: LLM, sampling_params: SamplingParams, seed: int = 0):
|
||||
"""Generate enough filler requests to allocate the entire GPU KV cache.
|
||||
|
||||
This pushes all prior blocks through the free queue so that the lazy
|
||||
cursor offloads them to CPU before they are evicted.
|
||||
"""
|
||||
cache_config = llm.llm_engine.vllm_config.cache_config
|
||||
num_gpu_blocks = cache_config.num_gpu_blocks
|
||||
block_size = cache_config.block_size
|
||||
# Use 1.2x GPU capacity to give the lazy cursor enough scheduling steps
|
||||
# to walk past all target blocks near the tail of the free queue.
|
||||
total_tokens_needed = int(num_gpu_blocks * block_size * 1.5)
|
||||
|
||||
# Use token-id prompts so each filler is unique (no prefix sharing).
|
||||
# Split into multiple requests to stay under max_model_len.
|
||||
max_tokens_per_req = 4096
|
||||
num_fillers = (total_tokens_needed + max_tokens_per_req - 1) // max_tokens_per_req
|
||||
batch_size = 10
|
||||
for i in range(0, num_fillers, batch_size):
|
||||
batch_end = min(i + batch_size, num_fillers)
|
||||
filler_prompts = []
|
||||
for j in range(i, batch_end):
|
||||
ids = [seed * num_fillers + j + 1] * max_tokens_per_req
|
||||
filler_prompts.append(TokensPrompt(prompt_token_ids=ids))
|
||||
llm.generate(filler_prompts, sampling_params, use_tqdm=False)
|
||||
|
||||
|
||||
def _accuracy_test(llm: LLM, lazy: bool = False):
|
||||
"""Verify that CPU-loaded KV produces correct output."""
|
||||
sampling_params = SamplingParams(max_tokens=1, temperature=0)
|
||||
prompt = "hi " * 2000 + "Let's count to ten. One, two, three, "
|
||||
|
||||
# Cold run — populate GPU cache and trigger CPU offload
|
||||
cold_output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
|
||||
|
||||
# CPU hit runs
|
||||
test_count = 10
|
||||
success_count = 0
|
||||
expected = cold_output.outputs[0].text
|
||||
for i in range(test_count):
|
||||
if lazy:
|
||||
_flush_gpu_cache(llm, sampling_params, seed=i)
|
||||
time.sleep(2) # let engine core drain pending transfers
|
||||
|
||||
# Reset GPU prefix cache so next run must load from CPU
|
||||
if not llm.reset_prefix_cache():
|
||||
print(f"GPU prefix cache reset failed for iteration {i}")
|
||||
|
||||
output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
|
||||
if output.outputs[0].text == expected:
|
||||
success_count += 1
|
||||
|
||||
assert success_count >= 0.5 * test_count, (
|
||||
f"Accuracy too low: {success_count}/{test_count} matched '{expected}'"
|
||||
)
|
||||
|
||||
|
||||
def _latency_test(llm: LLM, lazy: bool = False):
|
||||
"""Verify CPU cache hit is faster than cold compute."""
|
||||
sampling_params = SamplingParams(max_tokens=1, seed=42)
|
||||
prompt_token_ids = [0] * 10001
|
||||
|
||||
num_times_cpu_better = 0
|
||||
num_tests = 10
|
||||
for i in range(num_tests):
|
||||
prompt_token_ids[0] = i
|
||||
prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
|
||||
|
||||
# Cold
|
||||
time.sleep(2) # let engine core drain pending transfers
|
||||
if not llm.reset_prefix_cache():
|
||||
print(f"GPU prefix cache reset failed for iteration {i}")
|
||||
start = time.time()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
cold_time = time.time() - start
|
||||
|
||||
if lazy:
|
||||
_flush_gpu_cache(llm, sampling_params, seed=i)
|
||||
else:
|
||||
# Eager mode: GPU hit ensures store completion is processed.
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
|
||||
time.sleep(2) # let engine core drain pending transfers
|
||||
if not llm.reset_prefix_cache():
|
||||
print(f"GPU prefix cache reset failed for iteration {i}")
|
||||
|
||||
# CPU hit
|
||||
start = time.time()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
cpu_time = time.time() - start
|
||||
|
||||
if cpu_time < cold_time:
|
||||
num_times_cpu_better += 1
|
||||
|
||||
assert num_times_cpu_better >= 0.8 * num_tests, (
|
||||
f"CPU hit only faster {num_times_cpu_better}/{num_tests} times"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.slow_test
|
||||
@pytest.mark.parametrize("model", SMALL_MODELS)
|
||||
def test_simple_cpu_offload_accuracy(model: str):
|
||||
"""Store to CPU, reset GPU, load from CPU; verify output matches baseline."""
|
||||
llm = _make_llm(model, False, 1 << 30) # 1GB
|
||||
try:
|
||||
_accuracy_test(llm, lazy=False)
|
||||
finally:
|
||||
del llm
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.slow_test
|
||||
@pytest.mark.parametrize("model", PERF_MODELS)
|
||||
def test_simple_cpu_offload_perf_latency(model: str):
|
||||
"""CPU KV hit should beat cold prefill on long context (large models only)."""
|
||||
llm = _make_llm(model, False, 10 << 30) # 10GB
|
||||
try:
|
||||
_latency_test(llm, lazy=False)
|
||||
finally:
|
||||
del llm
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.slow_test
|
||||
@pytest.mark.parametrize("model", SMALL_MODELS)
|
||||
def test_simple_cpu_offload_accuracy_lazy(model: str):
|
||||
"""Lazy mode: flush GPU cache to trigger CPU offload, then verify hit."""
|
||||
# CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
|
||||
llm = _make_llm(model, True, 80 << 30) # 80GB
|
||||
try:
|
||||
_accuracy_test(llm, lazy=True)
|
||||
finally:
|
||||
del llm
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.slow_test
|
||||
@pytest.mark.parametrize("model", PERF_MODELS)
|
||||
def test_simple_cpu_offload_perf_latency_lazy(model: str):
|
||||
"""Lazy mode: CPU KV hit should beat cold prefill (large models only)."""
|
||||
# CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
|
||||
llm = _make_llm(model, True, 80 << 30) # 80GB
|
||||
try:
|
||||
_latency_test(llm, lazy=True)
|
||||
finally:
|
||||
del llm
|
||||
1137
tests/v1/simple_kv_offload/test_scheduler.py
Normal file
1137
tests/v1/simple_kv_offload/test_scheduler.py
Normal file
File diff suppressed because it is too large
Load Diff
147
tests/v1/spec_decode/test_backup_token_async_spec.py
Normal file
147
tests/v1/spec_decode/test_backup_token_async_spec.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Regression tests for the backup token fix in prepare_next_token_ids_padded.
|
||||
|
||||
Fixes #38098: with async scheduling, seq_lens_cpu is inflated by unaccepted
|
||||
draft token placeholders, causing get_token_id() to return -1.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
def __init__(self, prompt_tokens: list[int], output_tokens: list[int]):
|
||||
self.num_prompt_tokens = len(prompt_tokens)
|
||||
self._prompt = prompt_tokens
|
||||
self._output = output_tokens
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self._output)
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
return self._prompt[idx]
|
||||
out_idx = idx - self.num_prompt_tokens
|
||||
if out_idx < len(self._output):
|
||||
return self._output[out_idx]
|
||||
return -1 # out of range
|
||||
|
||||
|
||||
class _FakeInputBatch:
|
||||
def __init__(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
num_tokens_no_spec: list[int],
|
||||
vocab_size: int = 32000,
|
||||
):
|
||||
self.req_ids = req_ids
|
||||
self.num_reqs = len(req_ids)
|
||||
self.vocab_size = vocab_size
|
||||
self.num_tokens_no_spec = np.array(num_tokens_no_spec, dtype=np.int64)
|
||||
|
||||
|
||||
def _make_requests(
|
||||
req_ids: list[str],
|
||||
prompt_lens: list[int],
|
||||
output_lens: list[int],
|
||||
) -> dict[str, _FakeRequest]:
|
||||
requests = {}
|
||||
for rid, plen, olen in zip(req_ids, prompt_lens, output_lens):
|
||||
requests[rid] = _FakeRequest(list(range(plen)), list(range(1000, 1000 + olen)))
|
||||
return requests
|
||||
|
||||
|
||||
def _backup_buggy(
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
requests: dict[str, _FakeRequest],
|
||||
batch: _FakeInputBatch,
|
||||
) -> list[int]:
|
||||
"""Old logic: uses seq_lens_cpu directly (may be inflated)."""
|
||||
n = batch.num_reqs
|
||||
return [
|
||||
requests[batch.req_ids[i]].get_token_id(int(seq_lens_cpu[i])) for i in range(n)
|
||||
]
|
||||
|
||||
|
||||
def _backup_fixed(
|
||||
requests: dict[str, _FakeRequest],
|
||||
batch: _FakeInputBatch,
|
||||
) -> list[int]:
|
||||
"""New logic: uses num_tokens_no_spec - 1 (last committed token)."""
|
||||
n = batch.num_reqs
|
||||
idx = (batch.num_tokens_no_spec[:n] - 1).tolist()
|
||||
return [requests[batch.req_ids[i]].get_token_id(int(idx[i])) for i in range(n)]
|
||||
|
||||
|
||||
class TestBackupTokenAsyncSpec:
|
||||
def test_no_inflation_fixed_returns_last_token(self):
|
||||
req_ids = ["r0", "r1"]
|
||||
requests = _make_requests(req_ids, [3, 3], [2, 2])
|
||||
batch = _FakeInputBatch(req_ids, [5, 5])
|
||||
# idx = 5-1 = 4 → output[1] = 1001
|
||||
assert _backup_fixed(requests, batch) == [1001, 1001]
|
||||
|
||||
def test_inflation_buggy_returns_placeholder(self):
|
||||
req_ids = ["r0", "r1"]
|
||||
requests = _make_requests(req_ids, [3, 3], [2, 2])
|
||||
batch = _FakeInputBatch(req_ids, [5, 5])
|
||||
# inflated by 3 spec tokens → idx 8 is out of range
|
||||
seq_lens = torch.tensor([8, 8], dtype=torch.int64)
|
||||
assert _backup_buggy(seq_lens, requests, batch) == [-1, -1]
|
||||
|
||||
def test_inflation_fixed_returns_correct_token(self):
|
||||
req_ids = ["r0", "r1"]
|
||||
requests = _make_requests(req_ids, [3, 3], [2, 2])
|
||||
batch = _FakeInputBatch(req_ids, [5, 5])
|
||||
assert _backup_fixed(requests, batch) == [1001, 1001]
|
||||
|
||||
def test_mixed_inflation_per_request(self):
|
||||
req_ids = ["r0", "r1", "r2"]
|
||||
requests = {
|
||||
"r0": _FakeRequest([0, 1], [1000, 1001, 1002]),
|
||||
"r1": _FakeRequest([0, 1, 2, 3], [2000]),
|
||||
"r2": _FakeRequest([0], [3000, 3001, 3002, 3003]),
|
||||
}
|
||||
batch = _FakeInputBatch(req_ids, [5, 5, 5])
|
||||
seq_lens = torch.tensor([7, 9, 5], dtype=torch.int64)
|
||||
|
||||
assert _backup_buggy(seq_lens, requests, batch) == [-1, -1, -1]
|
||||
assert _backup_fixed(requests, batch) == [1002, 2000, 3003]
|
||||
|
||||
def test_prefill_only_request(self):
|
||||
"""No output tokens yet — backup should be the last prompt token."""
|
||||
req_ids = ["r0"]
|
||||
requests = {"r0": _FakeRequest([10, 20, 30], [])}
|
||||
batch = _FakeInputBatch(req_ids, [3])
|
||||
# idx = 3-1 = 2 → prompt[2] = 30
|
||||
assert _backup_fixed(requests, batch) == [30]
|
||||
|
||||
@pytest.mark.parametrize("num_spec_tokens", [1, 2, 3, 4, 5])
|
||||
def test_various_spec_token_counts(self, num_spec_tokens: int):
|
||||
req_ids = ["r0"]
|
||||
requests = {"r0": _FakeRequest([0, 1, 2], list(range(1000, 1005)))}
|
||||
batch = _FakeInputBatch(req_ids, [8])
|
||||
# idx = 8-1 = 7 → output[4] = 1004
|
||||
assert _backup_fixed(requests, batch) == [1004]
|
||||
|
||||
def test_buggy_code_was_always_off_by_one(self):
|
||||
"""The original code used seq_len as index, which is always one past
|
||||
the end of output_token_ids even without async inflation."""
|
||||
req_ids = ["r0"]
|
||||
requests = {"r0": _FakeRequest([0, 1, 2], [1000, 1001])}
|
||||
batch = _FakeInputBatch(req_ids, [5])
|
||||
|
||||
# no inflation: seq_len == num_tokens == 5 → idx 5 is out of range
|
||||
seq_lens = torch.tensor([5], dtype=torch.int64)
|
||||
assert _backup_buggy(seq_lens, requests, batch) == [-1]
|
||||
assert _backup_fixed(requests, batch) == [1001]
|
||||
|
||||
# with inflation: still -1, fixed still correct
|
||||
seq_lens_inf = torch.tensor([8], dtype=torch.int64)
|
||||
assert _backup_buggy(seq_lens_inf, requests, batch) == [-1]
|
||||
assert _backup_fixed(requests, batch) == [1001]
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -111,16 +112,14 @@ def test_prepare_next_token_ids():
|
||||
|
||||
num_requests = 4
|
||||
num_speculative_tokens = 4
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[num_speculative_tokens + 1] * num_requests,
|
||||
query_lens=[num_speculative_tokens + 1] * num_requests,
|
||||
)
|
||||
|
||||
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
|
||||
mock_input_batch = mock.MagicMock(spec=InputBatch)
|
||||
mock_input_batch.req_ids = req_ids
|
||||
mock_input_batch.num_reqs = num_requests
|
||||
mock_input_batch.vocab_size = 100
|
||||
mock_input_batch.num_tokens_no_spec = np.array(
|
||||
[num_speculative_tokens + 1] * num_requests
|
||||
)
|
||||
|
||||
mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
|
||||
mock_requests = {}
|
||||
@@ -165,19 +164,12 @@ def test_prepare_next_token_ids():
|
||||
|
||||
assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=BLOCK_SIZE,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expected_valid_sampled_tokens_count = torch.tensor(
|
||||
[2, 5, 0, 0], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
next_token_ids_from_padded, valid_sampled_tokens_count = (
|
||||
proposer.prepare_next_token_ids_padded(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
sampled_token_ids_tensor,
|
||||
mock_requests,
|
||||
mock_input_batch,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -132,16 +133,12 @@ def test_prepare_next_token_ids_padded():
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
num_requests = 4
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[5] * num_requests,
|
||||
query_lens=[5] * num_requests,
|
||||
)
|
||||
|
||||
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
|
||||
mock_input_batch = mock.MagicMock(spec=InputBatch)
|
||||
mock_input_batch.req_ids = req_ids
|
||||
mock_input_batch.num_reqs = num_requests
|
||||
mock_input_batch.vocab_size = 100
|
||||
mock_input_batch.num_tokens_no_spec = np.array([5] * num_requests)
|
||||
|
||||
mock_requests = {}
|
||||
for req_id in req_ids:
|
||||
@@ -174,12 +171,6 @@ def test_prepare_next_token_ids_padded():
|
||||
|
||||
proposer = _create_proposer(num_speculative_tokens=1)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
|
||||
# It doesn't depend on whether the request is discarded
|
||||
expected_valid_sampled_tokens_count = torch.tensor(
|
||||
@@ -187,7 +178,6 @@ def test_prepare_next_token_ids_padded():
|
||||
)
|
||||
|
||||
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
sampled_token_ids,
|
||||
mock_requests,
|
||||
mock_input_batch,
|
||||
|
||||
@@ -657,7 +657,11 @@ class VllmConfig:
|
||||
)
|
||||
|
||||
if kv_offloading_backend == "native":
|
||||
self.kv_transfer_config.kv_connector = "OffloadingConnector"
|
||||
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
|
||||
config_connector = "SimpleCPUOffloadConnector"
|
||||
else:
|
||||
config_connector = "OffloadingConnector"
|
||||
self.kv_transfer_config.kv_connector = config_connector
|
||||
self.kv_transfer_config.kv_connector_extra_config.update(
|
||||
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
|
||||
)
|
||||
|
||||
@@ -202,6 +202,7 @@ KVConnectorFactory.register_connector(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
|
||||
"DecodeBenchConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
|
||||
@@ -213,3 +214,9 @@ KVConnectorFactory.register_connector(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector",
|
||||
"FlexKVConnectorV1",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"SimpleCPUOffloadConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector",
|
||||
"SimpleCPUOffloadConnector",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""SimpleCPUOffloadConnector: minimal CPU KV cache offloading."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
SupportsHMA,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.simple_kv_offload.manager import (
|
||||
SimpleCPUOffloadScheduler,
|
||||
)
|
||||
from vllm.v1.simple_kv_offload.metadata import (
|
||||
SimpleCPUOffloadMetadata,
|
||||
)
|
||||
from vllm.v1.simple_kv_offload.worker import (
|
||||
SimpleCPUOffloadWorker,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Default CPU capacity: 8 GB
|
||||
DEFAULT_CPU_CAPACITY_BYTES = 8 * (1024**3)
|
||||
|
||||
|
||||
class SimpleCPUOffloadConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
"""CPU KV cache offloading with custom kernel transfers and BlockPool LRU."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
enable_prefix_caching = vllm_config.cache_config.enable_prefix_caching
|
||||
extra_config = self._kv_transfer_config.kv_connector_extra_config or {}
|
||||
|
||||
cpu_capacity_bytes = int(
|
||||
extra_config.get("cpu_bytes_to_use", DEFAULT_CPU_CAPACITY_BYTES)
|
||||
)
|
||||
# cpu_bytes_to_use is server-wide for compatibility;
|
||||
# cpu_bytes_to_use_per_rank overrides for per-rank capacity.
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
cpu_capacity_per_rank = cpu_capacity_bytes // world_size
|
||||
if "cpu_bytes_to_use_per_rank" in extra_config:
|
||||
explicit = int(extra_config["cpu_bytes_to_use_per_rank"])
|
||||
if explicit != cpu_capacity_per_rank:
|
||||
logger.warning(
|
||||
"cpu_bytes_to_use_per_rank (%.2f GB) != "
|
||||
"cpu_bytes_to_use/world_size (%.2f GB). Using per-rank value.",
|
||||
explicit / (1024**3),
|
||||
cpu_capacity_per_rank / (1024**3),
|
||||
)
|
||||
cpu_capacity_per_rank = explicit
|
||||
|
||||
lazy_offload = bool(extra_config.get("lazy_offload", False))
|
||||
|
||||
self.scheduler_manager: SimpleCPUOffloadScheduler | None = None
|
||||
self.worker_handler: SimpleCPUOffloadWorker | None = None
|
||||
|
||||
if not enable_prefix_caching:
|
||||
logger.warning(
|
||||
"Detected prefix caching disabled, disabling CPU offload "
|
||||
"since it requires prefix caching."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"SimpleCPUOffloadConnector: role=%s, "
|
||||
"per_rank=%.2f GB, world_size=%d, mode=%s",
|
||||
role.name,
|
||||
cpu_capacity_per_rank / (1024**3),
|
||||
world_size,
|
||||
"lazy" if lazy_offload else "eager",
|
||||
)
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.scheduler_manager = SimpleCPUOffloadScheduler(
|
||||
vllm_config,
|
||||
kv_cache_config,
|
||||
cpu_capacity_per_rank,
|
||||
lazy_offload=lazy_offload,
|
||||
)
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.worker_handler = SimpleCPUOffloadWorker(
|
||||
vllm_config, kv_cache_config, cpu_capacity_per_rank
|
||||
)
|
||||
|
||||
# --- Worker-side methods ---
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||
if self.worker_handler is not None:
|
||||
self.worker_handler.register_kv_caches(kv_caches)
|
||||
|
||||
def bind_connector_metadata(
|
||||
self,
|
||||
connector_metadata: KVConnectorMetadata,
|
||||
) -> None:
|
||||
super().bind_connector_metadata(connector_metadata)
|
||||
if self.worker_handler is not None:
|
||||
assert isinstance(connector_metadata, SimpleCPUOffloadMetadata)
|
||||
self.worker_handler.bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
super().clear_connector_metadata()
|
||||
if self.worker_handler is not None:
|
||||
self.worker_handler.clear_connector_metadata()
|
||||
|
||||
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata) -> None:
|
||||
if self.worker_handler is not None:
|
||||
assert isinstance(kv_connector_metadata, SimpleCPUOffloadMetadata)
|
||||
self.worker_handler.handle_preemptions(kv_connector_metadata)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
pass # Launch loads ops in get_finished() after launching model execution
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass # Always load asynchronously and deferred to get_finished()
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass # Always save asynchronously and deferred to get_finished()
|
||||
|
||||
def wait_for_save(self) -> None:
|
||||
pass # All stores are driven by get_finished() and no wait needed
|
||||
|
||||
def get_finished(
|
||||
self,
|
||||
finished_req_ids: set[str],
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
if self.worker_handler is not None:
|
||||
return self.worker_handler.get_finished(finished_req_ids)
|
||||
return None, None
|
||||
|
||||
def build_connector_worker_meta(self):
|
||||
if self.worker_handler is not None:
|
||||
return self.worker_handler.build_connector_worker_meta()
|
||||
return None
|
||||
|
||||
# --- Scheduler-side methods ---
|
||||
|
||||
# NOTE: New API only for SimpleCPUOffloadConnector.
|
||||
def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None:
|
||||
if self.scheduler_manager is not None:
|
||||
self.scheduler_manager.bind_gpu_block_pool(gpu_block_pool)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
)
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int,
|
||||
) -> None:
|
||||
if self.scheduler_manager is not None:
|
||||
self.scheduler_manager.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens
|
||||
)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.build_connector_meta(scheduler_output)
|
||||
return SimpleCPUOffloadMetadata()
|
||||
|
||||
def update_connector_output(
|
||||
self,
|
||||
connector_output: KVConnectorOutput,
|
||||
) -> None:
|
||||
if self.scheduler_manager is not None:
|
||||
self.scheduler_manager.update_connector_output(connector_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.request_finished(request, block_ids)
|
||||
return False, None
|
||||
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.request_finished_all_groups(
|
||||
request, block_ids
|
||||
)
|
||||
return False, None
|
||||
|
||||
# NOTE: New API only for SimpleCPUOffloadConnector.
|
||||
def has_pending_transfers(self) -> bool:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.has_pending_stores()
|
||||
return False
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.take_events()
|
||||
return []
|
||||
|
||||
def reset_cache(self) -> bool | None:
|
||||
raise NotImplementedError(
|
||||
"SimpleCPUOffloadConnector does not support reset_cache(). "
|
||||
"reset_prefix_cache() requires synchronizing all pending "
|
||||
"CPU offload transfers before clearing GPU prefix cache blocks, "
|
||||
"which is not yet implemented."
|
||||
)
|
||||
11
vllm/envs.py
11
vllm/envs.py
@@ -54,6 +54,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||
VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512
|
||||
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
|
||||
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
|
||||
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
|
||||
@@ -842,6 +843,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
),
|
||||
# Enable SPMD mode for TPU backend.
|
||||
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
|
||||
# Maximum size (in MB) for logits tensor in sparse MLA indexer prefill chunks.
|
||||
# Bounds the [M, N] float32 logits tensor to prevent CUDA OOM.
|
||||
# Default: 512 MB
|
||||
"VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int(
|
||||
os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512")
|
||||
),
|
||||
# If set, the OpenAI API server will stay alive even after the underlying
|
||||
# AsyncLLMEngine errors and stops serving requests
|
||||
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
|
||||
@@ -1655,6 +1662,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool(
|
||||
int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0"))
|
||||
),
|
||||
# Enable simple KV offload.
|
||||
"VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0"))
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
@@ -51,6 +52,14 @@ def sparse_attn_indexer(
|
||||
((total_seq_lens, head_dim), torch.float8_e4m3fn),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
|
||||
# Dummy allocation to simulate for peak logits tensor memory during inference.
|
||||
# FP8 elements so elements == bytes
|
||||
max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
|
||||
_ = torch.empty(
|
||||
max_logits_elems, dtype=torch.uint8, device=hidden_states.device
|
||||
)
|
||||
|
||||
return sparse_attn_indexer_fake(
|
||||
hidden_states,
|
||||
k_cache_prefix,
|
||||
@@ -101,13 +110,16 @@ def sparse_attn_indexer(
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
|
||||
k_scale = k_scale_full[: chunk.total_seq_lens]
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
k_fp8,
|
||||
k_scale,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
)
|
||||
|
||||
if not chunk.skip_kv_gather:
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
k_fp8,
|
||||
k_scale,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
)
|
||||
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8[chunk.token_start : chunk.token_end],
|
||||
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||
|
||||
@@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -22,7 +23,6 @@ from vllm.v1.attention.backend import (
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
@@ -30,6 +30,55 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def split_indexer_prefill_chunks(
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
query_lens_cpu: torch.Tensor,
|
||||
workspace_size: int,
|
||||
max_logits_bytes: int,
|
||||
request_offset: int = 0,
|
||||
) -> list[tuple[slice, slice]]:
|
||||
"""
|
||||
Split prefill requests into chunks for the sparse indexer, respecting:
|
||||
- N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
|
||||
- Logits constraint: M * N * 4 <= max_logits_bytes
|
||||
|
||||
When a single request-level chunk still exceeds the logits budget,
|
||||
sub-chunks on the query dimension (M) to bound peak memory.
|
||||
|
||||
Returns list of (req_slice, query_slice) tuples.
|
||||
"""
|
||||
chunks: list[tuple[slice, slice]] = []
|
||||
n = len(seq_lens_cpu)
|
||||
max_logits_elems = max_logits_bytes // 4
|
||||
end = 0
|
||||
|
||||
while end < n:
|
||||
start, chunk_m, chunk_n = end, 0, 0
|
||||
|
||||
while end < n:
|
||||
q, s = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
|
||||
new_m, new_n = chunk_m + q, chunk_n + s
|
||||
if new_n <= workspace_size and new_m * new_n <= max_logits_elems:
|
||||
chunk_m, chunk_n = new_m, new_n
|
||||
end += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# A single request can exceed the budget, requiring sub-chunking
|
||||
# on the query dimension.
|
||||
if end == start:
|
||||
chunk_m, chunk_n = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
|
||||
end += 1
|
||||
|
||||
req_slice = slice(start + request_offset, end + request_offset)
|
||||
max_q = max(1, max_logits_elems // chunk_n) if chunk_n > 0 else chunk_m
|
||||
for q_off in range(0, chunk_m, max_q):
|
||||
sub_m = min(max_q, chunk_m - q_off)
|
||||
chunks.append((req_slice, slice(q_off, q_off + sub_m)))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
skip_kv_gather: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -271,43 +321,51 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
)
|
||||
|
||||
def build_one_prefill_chunk(
|
||||
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
|
||||
):
|
||||
self,
|
||||
req_slice: slice,
|
||||
query_slice: slice,
|
||||
query_start_loc_cpu,
|
||||
seq_lens_cpu,
|
||||
block_table,
|
||||
skip_kv_gather: bool = False,
|
||||
) -> DeepseekV32IndexerPrefillChunkMetadata:
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc_cpu[reqs_start : reqs_end + 1]
|
||||
- query_start_loc_cpu[reqs_start]
|
||||
query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
|
||||
- query_start_loc_cpu[req_slice.start]
|
||||
)
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
|
||||
prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[req_slice.start].item()
|
||||
total_seq_lens = seq_lens_cpu[req_slice].sum()
|
||||
num_reqs = req_slice.stop - req_slice.start
|
||||
seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
|
||||
token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
|
||||
self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
|
||||
token_to_seq = torch.repeat_interleave(
|
||||
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
|
||||
).to(self.device)
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
|
||||
seq_lens_cpu[req_slice].cumsum(dim=0),
|
||||
]
|
||||
)
|
||||
.to(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seqlen_ks=cu_seqlen_ks[query_slice],
|
||||
cu_seqlen_ke=cu_seqlen_ke[query_slice],
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
token_to_seq=token_to_seq,
|
||||
total_seq_lens=total_seq_lens,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
block_table=block_table[req_slice],
|
||||
token_start=token_start + query_slice.start,
|
||||
token_end=token_start + query_slice.stop,
|
||||
num_reqs=num_reqs,
|
||||
skip_kv_gather=skip_kv_gather,
|
||||
)
|
||||
|
||||
def build(
|
||||
@@ -333,20 +391,27 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
prefill_query_lens_cpu = torch.diff(
|
||||
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
|
||||
)
|
||||
max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
|
||||
chunk_specs = split_indexer_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu[num_decodes:],
|
||||
prefill_query_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
max_logits_bytes,
|
||||
request_offset=num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start,
|
||||
reqs_end,
|
||||
req_slice,
|
||||
query_slice,
|
||||
query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor,
|
||||
skip_kv_gather=query_slice.start > 0,
|
||||
)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
for req_slice, query_slice in chunk_specs
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks,
|
||||
|
||||
@@ -234,6 +234,13 @@ class Scheduler(SchedulerInterface):
|
||||
hash_block_size=self.block_size,
|
||||
metrics_collector=self.kv_metrics_collector,
|
||||
)
|
||||
# Bind GPU block pool to the KV connector. This must happen after
|
||||
# kv_cache_manager is constructed so block_pool is available.
|
||||
if self.connector is not None and hasattr(
|
||||
self.connector, "bind_gpu_block_pool"
|
||||
):
|
||||
self.connector.bind_gpu_block_pool(self.kv_cache_manager.block_pool)
|
||||
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
self.scheduler_reserve_full_isl = (
|
||||
|
||||
@@ -281,8 +281,16 @@ class PromptTokenStats:
|
||||
|
||||
self.computed += prompt_len - num_cached_tokens
|
||||
self.external_kv_transfer += num_external_computed_tokens
|
||||
self.local_cache_hit += (
|
||||
num_cached_tokens + recomputed - num_external_computed_tokens
|
||||
# FIXME(yifan): local_cache_hit can go negative after preemption.
|
||||
# num_cached_tokens is a one-time snapshot from first scheduling and
|
||||
# is never reset on preemption, while num_external_computed_tokens is
|
||||
# overwritten on re-scheduling. If CPU offload finds more tokens on
|
||||
# the second pass than the original total, the subtraction underflows.
|
||||
# A fundamental fix is to track the first-time num_external_computed_tokens
|
||||
# as a separate metric rather than reusing num_external_computed_tokens
|
||||
# for metric directly.
|
||||
self.local_cache_hit += max(
|
||||
0, (num_cached_tokens + recomputed - num_external_computed_tokens)
|
||||
)
|
||||
self.cached_tokens += num_cached_tokens
|
||||
self.recomputed_tokens += recomputed
|
||||
|
||||
0
vllm/v1/simple_kv_offload/__init__.py
Normal file
0
vllm/v1/simple_kv_offload/__init__.py
Normal file
97
vllm/v1/simple_kv_offload/copy_backend.py
Normal file
97
vllm/v1/simple_kv_offload/copy_backend.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""DMA copy backend for GPU<->CPU block transfers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.simple_kv_offload.cuda_mem_ops import (
|
||||
BatchMemcpyParams,
|
||||
build_params,
|
||||
copy_blocks,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DmaCopyBackend:
|
||||
"""cuMemcpyBatchAsync copy backend (background thread)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store_params: BatchMemcpyParams | None = None
|
||||
self._load_params: BatchMemcpyParams | None = None
|
||||
self._load_stream: torch.cuda.Stream | None = None
|
||||
self._store_stream: torch.cuda.Stream | None = None
|
||||
self._queue: queue.SimpleQueue | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._shutdown: bool = False
|
||||
|
||||
def init(
|
||||
self,
|
||||
gpu_caches: dict[str, torch.Tensor],
|
||||
cpu_caches: dict[str, torch.Tensor],
|
||||
device: torch.device,
|
||||
load_stream: torch.cuda.Stream,
|
||||
store_stream: torch.cuda.Stream,
|
||||
) -> None:
|
||||
self._load_stream = load_stream
|
||||
self._store_stream = store_stream
|
||||
|
||||
self._store_params = build_params(gpu_caches, cpu_caches, store_stream)
|
||||
self._load_params = build_params(cpu_caches, gpu_caches, load_stream)
|
||||
|
||||
self._queue = queue.SimpleQueue()
|
||||
self._thread = threading.Thread(
|
||||
target=self._copy_loop,
|
||||
args=(self._queue, device, load_stream, store_stream),
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def launch_copy(
|
||||
self,
|
||||
src_blocks: list[int],
|
||||
dst_blocks: list[int],
|
||||
is_store: bool,
|
||||
event_idx: int,
|
||||
events_list: list[tuple[int, torch.Event]],
|
||||
) -> None:
|
||||
params = self._store_params if is_store else self._load_params
|
||||
assert params is not None and self._queue is not None
|
||||
self._queue.put(
|
||||
(src_blocks, dst_blocks, params, is_store, event_idx, events_list)
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self._shutdown:
|
||||
return
|
||||
self._shutdown = True
|
||||
if self._queue is not None:
|
||||
self._queue.put(None)
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=5.0)
|
||||
|
||||
@staticmethod
|
||||
def _copy_loop(
|
||||
q: queue.SimpleQueue,
|
||||
device: torch.device,
|
||||
load_stream: torch.cuda.Stream,
|
||||
store_stream: torch.cuda.Stream,
|
||||
) -> None:
|
||||
current_platform.set_device(device)
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
return
|
||||
src_blocks, dst_blocks, params, is_store, event_idx, events_list = item
|
||||
copy_blocks(src_blocks, dst_blocks, params)
|
||||
stream = store_stream if is_store else load_stream
|
||||
event = torch.Event()
|
||||
event.record(stream)
|
||||
events_list.append((event_idx, event))
|
||||
153
vllm/v1/simple_kv_offload/cuda_mem_ops.py
Normal file
153
vllm/v1/simple_kv_offload/cuda_mem_ops.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Low-level CUDA memory helpers: pinning and batch DMA transfers."""
|
||||
|
||||
import ctypes
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def pin_tensor(tensor: torch.Tensor) -> None:
|
||||
"""Pin a CPU tensor via cudaHostRegister.
|
||||
|
||||
This bypasses PyTorch's CUDACachingHostAllocator which rounds
|
||||
every ``pin_memory=True`` allocation up to the next power of 2
|
||||
(e.g. 100 GB becomes 128 GB).
|
||||
"""
|
||||
err = torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.nbytes, 0)
|
||||
if err.value != 0:
|
||||
raise RuntimeError(f"cudaHostRegister failed: {err}")
|
||||
|
||||
|
||||
class _CUmemLocation(ctypes.Structure):
|
||||
_fields_ = [("type", ctypes.c_uint), ("id", ctypes.c_int)]
|
||||
|
||||
|
||||
class _CUmemcpyAttributes(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("srcAccessOrder", ctypes.c_uint),
|
||||
("srcLocHint", _CUmemLocation),
|
||||
("dstLocHint", _CUmemLocation),
|
||||
("flags", ctypes.c_uint),
|
||||
]
|
||||
|
||||
|
||||
_BATCH_MEMCPY_FUNC_TYPE = ctypes.CFUNCTYPE(
|
||||
ctypes.c_uint, # CUresult
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_size_t,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_size_t,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
)
|
||||
|
||||
# Resolved lazily on first use.
|
||||
_batch_memcpy_fn: Any = None
|
||||
|
||||
|
||||
def _resolve_batch_memcpy():
|
||||
"""Resolve cuMemcpyBatchAsync via cuGetProcAddress (one-time)."""
|
||||
from cuda.bindings import driver as drv
|
||||
|
||||
err, ptr, _ = drv.cuGetProcAddress(b"cuMemcpyBatchAsync", 12080, 0)
|
||||
if err != drv.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError(f"cuGetProcAddress(cuMemcpyBatchAsync) failed: {err}")
|
||||
return _BATCH_MEMCPY_FUNC_TYPE(ptr)
|
||||
|
||||
|
||||
class BatchMemcpyParams(NamedTuple):
|
||||
src_bases: np.ndarray # [num_layers] uint64 — data_ptr per layer
|
||||
dst_bases: np.ndarray # [num_layers] uint64
|
||||
bpb: np.ndarray # [num_layers] uint64 — bytes per block
|
||||
num_layers: int
|
||||
attrs: _CUmemcpyAttributes
|
||||
attrs_idx: ctypes.c_size_t
|
||||
# NOTE: cuMemcpyBatchAsync_v2() removed fail_idx field, but we use
|
||||
# cuMemcpyBatchAsync() with fail_idx for backward compatibility
|
||||
fail_idx: ctypes.c_size_t
|
||||
stream_handle: int # raw cudaStream_t / CUstream
|
||||
|
||||
|
||||
def build_params(
|
||||
src_caches: dict[str, torch.Tensor],
|
||||
dst_caches: dict[str, torch.Tensor],
|
||||
stream: torch.cuda.Stream,
|
||||
) -> BatchMemcpyParams:
|
||||
global _batch_memcpy_fn
|
||||
if _batch_memcpy_fn is None:
|
||||
_batch_memcpy_fn = _resolve_batch_memcpy()
|
||||
|
||||
assert list(src_caches.keys()) == list(dst_caches.keys())
|
||||
src_tensors = list(src_caches.values())
|
||||
dst_tensors = list(dst_caches.values())
|
||||
|
||||
src_bases, dst_bases, bpb = [], [], []
|
||||
for s, d in zip(src_tensors, dst_tensors):
|
||||
s_bpb = s.stride(0) * s.element_size()
|
||||
assert s_bpb == d.stride(0) * d.element_size()
|
||||
src_bases.append(s.data_ptr())
|
||||
dst_bases.append(d.data_ptr())
|
||||
bpb.append(s_bpb)
|
||||
|
||||
# Refer to https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f for details. # noqa: E501
|
||||
attrs = _CUmemcpyAttributes(srcAccessOrder=3) # ANY
|
||||
|
||||
return BatchMemcpyParams(
|
||||
src_bases=np.array(src_bases, dtype=np.uint64),
|
||||
dst_bases=np.array(dst_bases, dtype=np.uint64),
|
||||
bpb=np.array(bpb, dtype=np.uint64),
|
||||
num_layers=len(src_tensors),
|
||||
attrs=attrs,
|
||||
attrs_idx=ctypes.c_size_t(0),
|
||||
fail_idx=ctypes.c_size_t(0),
|
||||
stream_handle=stream.cuda_stream,
|
||||
)
|
||||
|
||||
|
||||
def copy_blocks(
|
||||
src_block_ids: list[int],
|
||||
dst_block_ids: list[int],
|
||||
params: BatchMemcpyParams,
|
||||
) -> None:
|
||||
"""Copy blocks via cuMemcpyBatchAsync."""
|
||||
n = len(src_block_ids)
|
||||
if n == 0:
|
||||
return
|
||||
|
||||
src_ids = np.array(src_block_ids, dtype=np.uint64)
|
||||
dst_ids = np.array(dst_block_ids, dtype=np.uint64)
|
||||
|
||||
src_all = (
|
||||
params.src_bases[:, None] + src_ids[None, :] * params.bpb[:, None]
|
||||
).ravel()
|
||||
dst_all = (
|
||||
params.dst_bases[:, None] + dst_ids[None, :] * params.bpb[:, None]
|
||||
).ravel()
|
||||
sz_all = np.repeat(params.bpb, n)
|
||||
|
||||
total = n * params.num_layers
|
||||
err = _batch_memcpy_fn(
|
||||
dst_all.ctypes.data,
|
||||
src_all.ctypes.data,
|
||||
sz_all.ctypes.data,
|
||||
total,
|
||||
ctypes.addressof(params.attrs),
|
||||
ctypes.byref(params.attrs_idx),
|
||||
1,
|
||||
ctypes.byref(params.fail_idx),
|
||||
params.stream_handle,
|
||||
)
|
||||
if err != 0:
|
||||
raise RuntimeError(
|
||||
f"cuMemcpyBatchAsync failed: err={err} failIdx={params.fail_idx.value}"
|
||||
)
|
||||
739
vllm/v1/simple_kv_offload/manager.py
Normal file
739
vllm/v1/simple_kv_offload/manager.py
Normal file
@@ -0,0 +1,739 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Scheduler-side manager for SimpleCPUOffloadConnector."""
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_coordinator import (
|
||||
KVCacheCoordinator,
|
||||
get_kv_cache_coordinator,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
MambaSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.simple_kv_offload.metadata import (
|
||||
SimpleCPUOffloadMetadata,
|
||||
SimpleCPUOffloadWorkerMetadata,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferMeta:
|
||||
gpu_block_ids: list[int]
|
||||
cpu_block_ids: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadRequestState:
|
||||
request: "Request"
|
||||
transfer_meta: TransferMeta
|
||||
load_event: int | None = None
|
||||
finished: bool = False
|
||||
|
||||
|
||||
# NOTE: This per-request state is only used in eager mode.
|
||||
@dataclass
|
||||
class StoreRequestState:
|
||||
request: "Request"
|
||||
# Accumulated block IDs from scheduler_output via yield_req_data.
|
||||
block_ids: tuple[list[int], ...]
|
||||
# Per-group cursors tracking how many blocks have been stored/skipped.
|
||||
num_stored_blocks: list[int]
|
||||
store_events: set[int] = field(default_factory=set)
|
||||
finished: bool = False
|
||||
|
||||
|
||||
class SimpleCPUOffloadScheduler:
|
||||
"""Scheduler-side manager for CPU offloading."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: "KVCacheConfig | None",
|
||||
cpu_capacity_bytes: int,
|
||||
lazy_offload: bool = False,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.enable_kv_cache_events = (
|
||||
vllm_config.kv_events_config is not None
|
||||
and vllm_config.kv_events_config.enable_kv_cache_events
|
||||
)
|
||||
# NOTE: We use the same block size for both GPU and CPU.
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
# Derive a CPU KVCacheConfig from the GPU config and build a coordinator
|
||||
assert kv_cache_config is not None
|
||||
self.cpu_kv_cache_config = self._derive_cpu_config(
|
||||
kv_cache_config, cpu_capacity_bytes
|
||||
)
|
||||
self.num_cpu_blocks = self.cpu_kv_cache_config.num_blocks
|
||||
# Find the full attention kv group for prefix cache matching.
|
||||
self.fa_gidx = -1
|
||||
for g_idx, g in enumerate(self.cpu_kv_cache_config.kv_cache_groups):
|
||||
if isinstance(g.kv_cache_spec, FullAttentionSpec):
|
||||
self.fa_gidx = g_idx
|
||||
break
|
||||
assert 0 <= self.fa_gidx < len(self.cpu_kv_cache_config.kv_cache_groups)
|
||||
|
||||
logger.info(
|
||||
"SimpleCPUOffloadScheduler: Allocating %d CPU blocks (%.2f GB, mode=%s)",
|
||||
self.num_cpu_blocks,
|
||||
cpu_capacity_bytes / (1024**3),
|
||||
"lazy" if lazy_offload else "eager",
|
||||
)
|
||||
|
||||
# TODO (yifan): maybe need to enable kv_cache_events and metrics_collector here.
|
||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
assert dcp_world_size == 1 and pcp_world_size == 1
|
||||
self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=self.cpu_kv_cache_config,
|
||||
max_model_len=vllm_config.model_config.max_model_len,
|
||||
use_eagle=False,
|
||||
enable_caching=True,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
hash_block_size=self.block_size,
|
||||
)
|
||||
self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool
|
||||
|
||||
# GPU block pool reference - bound after scheduler builds kv_cache_manager
|
||||
self._gpu_block_pool: BlockPool | None = None
|
||||
|
||||
# Load metadata
|
||||
self._reqs_to_load: dict[str, LoadRequestState] = {}
|
||||
# Inverse map: load_event_idx -> req_ids. Keyed by load_event_idx because
|
||||
# the worker reports completions by event index, not request id.
|
||||
self._load_event_to_reqs: dict[int, list[str]] = {}
|
||||
|
||||
# Store metadata
|
||||
self._lazy_mode = lazy_offload
|
||||
# Lazy mode: use a cursor to track the last scanned block in the GPU free queue.
|
||||
self._cursor: KVCacheBlock | None = None
|
||||
if self._lazy_mode:
|
||||
self._target_free = self._estimate_lazy_target_blocks(
|
||||
kv_cache_config,
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
)
|
||||
else:
|
||||
self._target_free = 0
|
||||
self._store_event_to_blocks: dict[int, TransferMeta] = {}
|
||||
# Eager mode only
|
||||
self._reqs_to_store: dict[str, StoreRequestState] = {}
|
||||
self._store_event_to_reqs: dict[int, list[str]] = {}
|
||||
|
||||
# Event counters
|
||||
self._load_event_counter: int = 0
|
||||
self._store_event_counter: int = 0
|
||||
|
||||
# For TP/PP: track partial store completions across steps.
|
||||
# Events must be reported by all world_size workers before considered complete.
|
||||
self._expected_worker_count = vllm_config.parallel_config.world_size
|
||||
self._store_event_pending_counts: dict[int, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def _derive_cpu_config(
|
||||
gpu_config: "KVCacheConfig", cpu_capacity_bytes: int
|
||||
) -> "KVCacheConfig":
|
||||
"""Derive a CPU KVCacheConfig from the GPU config.
|
||||
Same kv_cache_groups, num_blocks scaled by CPU/GPU memory ratio."""
|
||||
# Import here to avoid potential circular imports
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig as KVCacheConfigCls
|
||||
from vllm.v1.kv_cache_interface import KVCacheTensor
|
||||
|
||||
assert len(gpu_config.kv_cache_tensors) > 0
|
||||
|
||||
gpu_total_bytes = sum(t.size for t in gpu_config.kv_cache_tensors)
|
||||
num_gpu_blocks = gpu_config.num_blocks
|
||||
num_cpu_blocks = max(1, num_gpu_blocks * cpu_capacity_bytes // gpu_total_bytes)
|
||||
# Create CPU kv_cache_tensors mirroring GPU by scaling size proportionally.
|
||||
cpu_tensors = [
|
||||
KVCacheTensor(
|
||||
size=t.size // num_gpu_blocks * num_cpu_blocks,
|
||||
shared_by=list(t.shared_by),
|
||||
)
|
||||
for t in gpu_config.kv_cache_tensors
|
||||
]
|
||||
|
||||
return KVCacheConfigCls(
|
||||
num_blocks=num_cpu_blocks,
|
||||
kv_cache_tensors=cpu_tensors,
|
||||
kv_cache_groups=gpu_config.kv_cache_groups,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _estimate_lazy_target_blocks(
|
||||
kv_cache_config: "KVCacheConfig", max_num_batched_tokens: int
|
||||
) -> int:
|
||||
"""GPU blocks to keep available (free/offloaded) per step in lazy mode."""
|
||||
WATERMARK_RATIO = 1.0 # Reserve larger space to avoid running out of GPU blocks
|
||||
target = 0
|
||||
for g in kv_cache_config.kv_cache_groups:
|
||||
spec = g.kv_cache_spec
|
||||
if isinstance(spec, MambaSpec):
|
||||
target += 2
|
||||
elif isinstance(spec, SlidingWindowSpec):
|
||||
target += cdiv(spec.sliding_window, spec.block_size) + 1
|
||||
else:
|
||||
target += cdiv(max_num_batched_tokens, spec.block_size)
|
||||
return int(target * (1 + WATERMARK_RATIO))
|
||||
|
||||
def bind_gpu_block_pool(self, gpu_block_pool: BlockPool) -> None:
|
||||
"""Bind GPU block pool so that we can touch blocks during stores.
|
||||
Called by Scheduler after kv_cache_manager is ready."""
|
||||
self._gpu_block_pool = gpu_block_pool
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
"""Return (num_new_tokens, is_async) from consecutive CPU cache hits."""
|
||||
skipped = num_computed_tokens // self.block_size
|
||||
remaining_hashes = request.block_hashes[skipped:]
|
||||
|
||||
if not remaining_hashes:
|
||||
return 0, False
|
||||
# Must recompute at least the last token, matching the logic in
|
||||
# kv_cache_manager.get_computed_blocks().
|
||||
max_hit_len = request.num_tokens - 1 - num_computed_tokens
|
||||
if max_hit_len <= 0:
|
||||
return 0, False
|
||||
_, hit_length = self.cpu_coordinator.find_longest_cache_hit(
|
||||
remaining_hashes, max_hit_len
|
||||
)
|
||||
|
||||
if hit_length > 0:
|
||||
return hit_length, True
|
||||
return 0, False
|
||||
|
||||
# TODO(yifan): this API now only matches the suffix part of the prefix cache. A more
# general API should scan blocks in both GPU and CPU block pool in a single pass.
def update_state_after_alloc(
    self,
    request: "Request",
    blocks: "KVCacheBlocks",
    num_external_tokens: int,
) -> None:
    """Record allocation results and set up the CPU->GPU load (if any).

    Registers the request for store tracking (eager mode), then — when
    ``num_external_tokens`` > 0 — pairs the freshly allocated GPU blocks
    with the CPU-cached blocks found earlier and pins both sides until
    the async load completes.
    """
    req_id = request.request_id
    block_ids_by_group = blocks.get_block_ids()
    num_groups = len(block_ids_by_group)

    # Store tracking (eager mode only). Register the request;
    # block IDs are accumulated from scheduler_output in
    # _prepare_eager_store_specs via yield_req_data.
    if not self._lazy_mode and req_id not in self._reqs_to_store:
        self._reqs_to_store[req_id] = StoreRequestState(
            request=request,
            block_ids=tuple([] for _ in range(num_groups)),
            num_stored_blocks=[0] * num_groups,
        )

    # No CPU cache hit was reported for this request: nothing to load.
    if num_external_tokens == 0:
        return

    num_blocks_to_load = num_external_tokens // self.block_size
    assert num_blocks_to_load > 0

    # Blocks already hashed in the full-attention group count as computed
    # on the GPU side; the CPU hit starts right after them.
    skipped = sum(blk.block_hash is not None for blk in blocks.blocks[self.fa_gidx])
    num_computed_tokens = skipped * self.block_size
    hashes_to_load = request.block_hashes[skipped : skipped + num_blocks_to_load]

    # Find CPU cached blocks across all groups.
    max_hit_len = len(hashes_to_load) * self.block_size
    cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit(
        hashes_to_load, max_hit_len
    )
    # The hit must match what get_num_new_matched_tokens() promised.
    assert hit_length == num_external_tokens, (
        f"Expected {num_external_tokens} hit tokens, got {hit_length}"
    )

    # Build transfer pairs across all groups.
    total_computed_tokens = num_computed_tokens + num_external_tokens
    kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups

    gpu_block_ids: list[int] = []
    cpu_block_ids: list[int] = []
    cpu_blocks_to_touch: list[KVCacheBlock] = []

    for g in range(num_groups):
        cpu_blocks_g = cpu_hit_blocks[g]
        n_ext_g = len(cpu_blocks_g)
        if n_ext_g == 0:
            continue

        # Number of blocks in the computed range for this group.
        # NOTE: groups may have different block sizes.
        g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
        n_computed_g = cdiv(total_computed_tokens, g_block_size)

        # Back-trace: ext blocks sit at the tail of the computed range.
        gpu_ext_start = n_computed_g - n_ext_g
        group_gpu_ids = block_ids_by_group[g]

        for i, cpu_blk in enumerate(cpu_blocks_g):
            # Skip null blocks (e.g. sliding window or mamba padding).
            if cpu_blk.is_null:
                continue
            gpu_block_ids.append(group_gpu_ids[gpu_ext_start + i])
            cpu_block_ids.append(cpu_blk.block_id)
            cpu_blocks_to_touch.append(cpu_blk)

    # Touch CPU blocks to prevent eviction during async load.
    self.cpu_block_pool.touch(cpu_blocks_to_touch)

    # Touch GPU blocks to prevent freeing during async load.
    assert self._gpu_block_pool is not None
    self._gpu_block_pool.touch(
        [self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
    )

    # One pending load per request at a time.
    assert self._reqs_to_load.get(req_id) is None
    self._reqs_to_load[req_id] = LoadRequestState(
        request=request, transfer_meta=TransferMeta(gpu_block_ids, cpu_block_ids)
    )
|
||||
|
||||
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> SimpleCPUOffloadMetadata:
    """Assemble this step's scheduler->worker transfer metadata.

    Gathers at most one store event (GPU->CPU) and one load event
    (CPU->GPU) of flat block-ID lists, tagging each with a monotonically
    increasing event index so completions can be matched back later.
    """
    # --- Stores ---
    store_event = -1  # INVALID_JOB_ID: no store scheduled this step.
    store_gpu, store_cpu, store_req_ids = self.prepare_store_specs(scheduler_output)
    if store_gpu:
        store_event = self._store_event_counter
        self._store_event_counter += 1
        # Remember the block pairs so _process_store_event() can cache the
        # CPU blocks and release GPU refs on completion.
        self._store_event_to_blocks[store_event] = TransferMeta(
            store_gpu, store_cpu
        )
        if store_req_ids:  # For eager mode only, track req->blocks mapping
            self._store_event_to_reqs[store_event] = store_req_ids
            for req_id in store_req_ids:
                store_state = self._reqs_to_store.get(req_id)
                if store_state is not None:
                    store_state.store_events.add(store_event)

    # --- Loads ---
    load_event = -1  # INVALID_JOB_ID: no load scheduled this step.
    load_gpu: list[int] = []
    load_cpu: list[int] = []
    load_req_ids: list[str] = []
    for req_id, load_state in self._reqs_to_load.items():
        # Skip loads already submitted in a previous step.
        if load_state.load_event is not None:
            continue
        assert load_state.transfer_meta is not None
        load_gpu.extend(load_state.transfer_meta.gpu_block_ids)
        load_cpu.extend(load_state.transfer_meta.cpu_block_ids)
        load_req_ids.append(req_id)
    if load_req_ids:
        load_event = self._load_event_counter
        self._load_event_counter += 1
        for req_id in load_req_ids:
            self._reqs_to_load[req_id].load_event = load_event
        self._load_event_to_reqs[load_event] = load_req_ids

    result = SimpleCPUOffloadMetadata(
        load_event=load_event,
        load_gpu_blocks=load_gpu,
        load_cpu_blocks=load_cpu,
        load_event_to_reqs=self._load_event_to_reqs,
        store_event=store_event,
        store_gpu_blocks=store_gpu,
        store_cpu_blocks=store_cpu,
        # Preemption this step -> worker should flush pending transfers.
        need_flush=bool(scheduler_output.preempted_req_ids),
    )
    return result
|
||||
|
||||
def prepare_store_specs(
    self, scheduler_output: SchedulerOutput
) -> tuple[list[int], list[int], list[str]]:
    """Build (gpu_block_ids, cpu_block_ids, req_ids) for this step's store.

    Dispatches to the lazy (free-queue walk) or eager (per-request scan)
    strategy, depending on how the connector was configured.
    """
    if self._lazy_mode:
        return self._prepare_lazy_store_specs()
    return self._prepare_eager_store_specs(scheduler_output)
|
||||
|
||||
def _prepare_lazy_store_specs(
    self,
) -> tuple[list[int], list[int], list[str]]:
    """Single-pass cursor walk: offload cached GPU blocks near eviction.

    Walks the GPU free queue from the cursor, counting blocks that are
    free-or-offloaded (safe for the allocator to evict). Stops when
    target_free blocks are covered or CPU capacity is reached.

    Returns:
        (gpu_block_ids, cpu_block_ids, req_ids); req_ids is always empty
        because lazy mode keeps no per-request store state.
    """
    gpu_pool = self._gpu_block_pool
    if gpu_pool is None or self._target_free <= 0:
        return [], [], []

    free_queue = gpu_pool.free_block_queue
    cpu_pool = self.cpu_block_pool
    num_cpu_free = cpu_pool.get_num_free_blocks()

    # Validate cursor: stale if block was removed from free queue.
    if self._cursor is not None and self._cursor.ref_cnt > 0:
        self._cursor = None

    # Determine start node.
    if self._cursor is None:
        node = free_queue.fake_free_list_head.next_free_block
    else:
        node = self._cursor.next_free_block

    tail = free_queue.fake_free_list_tail
    gpu_ids: list[int] = []
    block_hashes: list[bytes] = []
    covered = 0  # free-queue nodes visited; counts toward target_free
    last_visited = self._cursor

    while (
        node is not None
        and node is not tail
        and covered < self._target_free
        and len(gpu_ids) < num_cpu_free
    ):
        last_visited = node
        bhash = node.block_hash

        # Offload only hashed, non-null blocks not already cached on CPU.
        if (
            bhash is not None
            and not node.is_null
            and cpu_pool.cached_block_hash_to_block.get_one_block(bhash) is None
        ):
            gpu_ids.append(node.block_id)
            block_hashes.append(bhash)

        covered += 1
        node = node.next_free_block

    # Resume from the last visited node next step.
    self._cursor = last_visited

    # Batch-allocate CPU blocks and stamp hashes.
    if gpu_ids:
        cpu_blocks = cpu_pool.get_new_blocks(len(gpu_ids))
        cpu_ids = [blk.block_id for blk in cpu_blocks]
        for cpu_blk, bhash in zip(cpu_blocks, block_hashes):  # type: ignore[assignment]
            cpu_blk._block_hash = bhash  # type: ignore[assignment]
        # Touch GPU blocks to prevent eviction during async copy.
        gpu_pool.touch([gpu_pool.blocks[bid] for bid in gpu_ids])
    else:
        cpu_ids = []

    return gpu_ids, cpu_ids, []
|
||||
|
||||
def _prepare_eager_store_specs(
    self, scheduler_output: SchedulerOutput
) -> tuple[list[int], list[int], list[str]]:
    """Identify newly computed blocks to offload from scheduler requests.

    Only considers blocks whose KV data has been **confirmed computed** by
    the GPU. This means blocks from the current step are NOT stored until the
    next step. If a request finishes in the same step as its last full block,
    that block may be missed. (TODO: flush on finish.)

    Returns:
        (gpu_block_ids, cpu_block_ids, req_ids) for the store event.
    """

    merged_gpu_block_ids: list[int] = []
    merged_cpu_block_ids: list[int] = []
    req_ids: list[str] = []

    gpu_block_pool = self._gpu_block_pool
    if gpu_block_pool is None:
        return [], [], []
    cpu_block_pool = self.cpu_block_pool
    num_free = cpu_block_pool.get_num_free_blocks()
    kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
    num_groups = len(kv_cache_groups)
    # GPU blocks already picked for store this step (dedup across requests).
    gpu_blocks_this_step: set[int] = set()

    for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
        state = self._reqs_to_store.get(req_id)
        if state is None or state.finished:
            continue

        # Accumulate new block IDs.
        if preempted:
            # Preemption invalidates accumulated progress; start over.
            state.block_ids = tuple([] for _ in range(num_groups))
            state.num_stored_blocks = [0] * num_groups
        if new_block_id_groups:
            for g in range(min(num_groups, len(new_block_id_groups))):
                if new_block_id_groups[g] is not None:
                    state.block_ids[g].extend(new_block_id_groups[g])

        num_new_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
        if num_new_tokens == 0:
            continue

        block_ids_by_group = state.block_ids
        if not block_ids_by_group:
            continue

        # --- Phase 1: Scan blocks, classify as cached vs to-store ---
        gpu_block_ids: list[int] = []
        block_hashes_to_store: list[bytes] = []
        advanced_per_group: list[int] = [0] * num_groups
        out_of_space = False
        # Confirmed tokens: KV data written and visible to all streams.
        req = state.request
        confirmed_tokens = req.num_computed_tokens - req.num_output_placeholders

        for g in range(num_groups):
            # FIXME (yifan): handle CPU cache eviction, where
            # num_stored_blocks can be stale and omit evicted blocks in
            # the middle of the request.
            already_stored_g = state.num_stored_blocks[g]
            group_gpu_ids = block_ids_by_group[g]

            # Cap to blocks with confirmed KV data.
            g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
            ready_blocks_g = confirmed_tokens // g_block_size
            scannable = group_gpu_ids[already_stored_g:ready_blocks_g]

            for gpu_block_id in scannable:
                gpu_block = gpu_block_pool.blocks[gpu_block_id]
                if gpu_block.is_null:
                    advanced_per_group[g] += 1
                    continue

                bhash_with_group = gpu_block.block_hash
                if bhash_with_group is None:
                    # Block not hashed yet; stop scanning this group.
                    break

                # Check if this group's data is already scheduled for store
                # in this step or already cached in CPU.
                if (
                    gpu_block_id in gpu_blocks_this_step
                    or cpu_block_pool.cached_block_hash_to_block.get_one_block(
                        bhash_with_group
                    )
                    is not None
                ):
                    advanced_per_group[g] += 1
                    continue

                # Reserve one free CPU block for this store.
                if num_free <= 0:
                    out_of_space = True
                    break
                num_free -= 1

                gpu_block_ids.append(gpu_block_id)
                block_hashes_to_store.append(bhash_with_group)
                advanced_per_group[g] += 1

            if out_of_space:
                break

        # --- Phase 2: Batch allocate CPU blocks and stamp hashes ---
        n_to_alloc = len(gpu_block_ids)
        if n_to_alloc > 0:
            cpu_blocks_alloc = cpu_block_pool.get_new_blocks(n_to_alloc)
            cpu_block_ids = [blk.block_id for blk in cpu_blocks_alloc]
            for cpu_blk, bhash in zip(cpu_blocks_alloc, block_hashes_to_store):
                cpu_blk._block_hash = bhash  # type: ignore[assignment]
        else:
            cpu_block_ids = []

        if cpu_block_ids:
            req_ids.append(req_id)
            merged_gpu_block_ids.extend(gpu_block_ids)
            merged_cpu_block_ids.extend(cpu_block_ids)
            gpu_blocks_this_step.update(gpu_block_ids)

            # Touch GPU blocks to prevent freeing during async copy
            gpu_block_pool.touch(
                [gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
            )

            logger.debug(
                "Request %s: Scheduling store of %d blocks to CPU (%d groups)",
                req_id,
                len(cpu_block_ids),
                num_groups,
            )

        # Advance per-group cursors (includes cached hits + newly stored)
        for g in range(num_groups):
            state.num_stored_blocks[g] += advanced_per_group[g]

    return merged_gpu_block_ids, merged_cpu_block_ids, req_ids
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput) -> None:
    """Handle async transfer completions from worker.

    Load completions arrive via finished_recving (real req_ids).
    Store completions arrive via kv_connector_worker_meta as
    per-event worker counts. We accumulate across steps and process
    a store event only when all workers have reported completion.
    """
    # --- Load completions ---
    for req_id in list(connector_output.finished_recving or []):
        self._cleanup_load_request(req_id)

    # --- Store completions ---
    meta = connector_output.kv_connector_worker_meta
    if not isinstance(meta, SimpleCPUOffloadWorkerMetadata):
        # No (or foreign) worker metadata this step; nothing more to do.
        return
    for event_idx, count in meta.completed_store_events.items():
        # Accumulate reports until every expected worker has confirmed.
        total = self._store_event_pending_counts.get(event_idx, 0) + count
        if total >= self._expected_worker_count:
            self._store_event_pending_counts.pop(event_idx, None)
            self._process_store_event(event_idx)
        else:
            self._store_event_pending_counts[event_idx] = total
|
||||
|
||||
def _process_store_event(self, event_idx: int) -> None:
    """Process a fully-completed store event.

    Registers the stored CPU blocks in the cache, releases block refs,
    and (eager mode only) updates per-request store tracking.
    """
    transfer = self._store_event_to_blocks.pop(event_idx)
    self._process_store_completion(transfer.gpu_block_ids, transfer.cpu_block_ids)
    logger.debug(
        "Store event %d completed: cached %d blocks to CPU",
        event_idx,
        len(transfer.cpu_block_ids),
    )

    # Eager only: update per-req state
    if not self._lazy_mode:
        for req_id in self._store_event_to_reqs.pop(event_idx, []):
            state = self._reqs_to_store.get(req_id)
            if state is None:
                continue
            state.store_events.discard(event_idx)
            # Cleanup was deferred in request_finished(); finish it once
            # this request's last in-flight store has completed.
            if state.finished and not state.store_events:
                self._cleanup_store_request(req_id)
|
||||
|
||||
def _process_store_completion(
    self, gpu_block_ids: list[int], cpu_block_ids: list[int]
) -> None:
    """Cache CPU blocks per-group and release GPU refs.

    Block hashes were stamped on CPU blocks at allocation time (in
    ``_prepare_*_store_specs``). Here we just register them in the
    cache map so they become discoverable by the load path.
    """
    assert len(cpu_block_ids) == len(gpu_block_ids)

    cpu_blocks = [self.cpu_block_pool.blocks[bid] for bid in cpu_block_ids]

    # Make each stored block findable by its (pre-stamped) hash.
    for cpu_block in cpu_blocks:
        bhash = cpu_block.block_hash
        assert bhash is not None
        self.cpu_block_pool.cached_block_hash_to_block.insert(bhash, cpu_block)

    # Free CPU and GPU blocks' ref counts to turn them into prefix cache
    self.cpu_block_pool.free_blocks(cpu_blocks)
    assert self._gpu_block_pool is not None
    self._gpu_block_pool.free_blocks(
        self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids
    )
|
||||
|
||||
def has_pending_stores(self) -> bool:
    """True while at least one store event has not been processed yet."""
    return len(self._store_event_to_blocks) > 0
|
||||
|
||||
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    """Always returns (False, None). GPU blocks are protected by ref_cnt,
    so the scheduler can free blocks immediately."""
    req_id = request.request_id

    # Handle load: defer cleanup if load is in-flight
    load_state = self._reqs_to_load.get(req_id)
    if load_state is not None:
        if load_state.load_event is not None:
            load_state.finished = True  # Defer: load in-flight
        else:
            self._cleanup_load_request(req_id)

    # Handle store (eager mode only): defer cleanup if stores in-flight.
    # Deferred cleanup is completed later in _process_store_event().
    if not self._lazy_mode:
        store_state = self._reqs_to_store.get(req_id)
        if store_state is not None:
            if store_state.store_events:
                store_state.finished = True  # Defer: stores in-flight
            else:
                self._cleanup_store_request(req_id)

    return False, None
|
||||
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
return self.request_finished(request, block_ids=[])
|
||||
|
||||
def _cleanup_load_request(self, req_id: str) -> None:
    """Release all load resources for a request.

    Shared between request_finished() and update_connector_output() paths.
    Removes the request from _reqs_to_load, cleans up event mappings,
    and frees CPU/GPU touch refs.
    """
    load_state = self._reqs_to_load.pop(req_id, None)
    if load_state is None:
        return

    # Drop only this request from the event mapping — other requests may
    # still be waiting on the same load event.
    event = load_state.load_event
    if event is not None:
        if (event_reqs := self._load_event_to_reqs.get(event)) is not None:
            with contextlib.suppress(ValueError):
                event_reqs.remove(req_id)
            if not event_reqs:
                self._load_event_to_reqs.pop(event, None)

    meta = load_state.transfer_meta
    if meta is None:
        return
    # Release the CPU-side touch refs taken when the load was scheduled.
    self.cpu_block_pool.free_blocks(
        self.cpu_block_pool.blocks[bid] for bid in meta.cpu_block_ids
    )
    # Release the GPU-side touch refs as well.
    assert self._gpu_block_pool is not None
    self._gpu_block_pool.free_blocks(
        self._gpu_block_pool.blocks[bid] for bid in meta.gpu_block_ids
    )
|
||||
|
||||
def _cleanup_store_request(self, req_id: str) -> None:
    """Release store metadata for a request.

    Metadata-only cleanup; no blocks are freed here. Completed store
    events still perform CPU caching and GPU ref release through
    _process_store_completion().
    """
    store_state = self._reqs_to_store.pop(req_id, None)
    if store_state is None:
        return
    # Unlink this request from every store event it was still part of.
    for event_idx in list(store_state.store_events):
        event_reqs = self._store_event_to_reqs.get(event_idx)
        if event_reqs is None:
            continue
        with contextlib.suppress(ValueError):
            event_reqs.remove(req_id)
        if not event_reqs:
            self._store_event_to_reqs.pop(event_idx, None)
    store_state.store_events.clear()
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
    """Drain and return pending KV-cache events from the CPU block pool."""
    return self.cpu_block_pool.take_events()
|
||||
60
vllm/v1/simple_kv_offload/metadata.py
Normal file
60
vllm/v1/simple_kv_offload/metadata.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Metadata for SimpleCPUOffloadConnector."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata,
|
||||
KVConnectorWorkerMetadata,
|
||||
)
|
||||
|
||||
# Sentinel event index meaning "no transfer scheduled this step".
INVALID_JOB_ID = -1
|
||||
|
||||
|
||||
@dataclass
class SimpleCPUOffloadMetadata(KVConnectorMetadata):
    """
    Metadata passed from scheduler to worker for CPU offload operations.

    The worker receives flat block lists keyed by a monotonic event_idx.
    Job->req_id translation is handled by the scheduler-side manager
    (via inverse maps), so the worker never knows about request identities.
    """

    # Load event per step. INVALID_JOB_ID means no blocks to load this step.
    load_event: int = INVALID_JOB_ID
    # Source (CPU) and destination (GPU) block IDs, paired by position.
    load_gpu_blocks: list[int] = field(default_factory=list)
    load_cpu_blocks: list[int] = field(default_factory=list)
    # Reverse map: load_event->req_ids, for tracking requests with finished load events
    load_event_to_reqs: dict[int, list[str]] = field(default_factory=dict)

    # Store event per step. INVALID_JOB_ID means no blocks to store this step.
    store_event: int = INVALID_JOB_ID
    # Source (GPU) and destination (CPU) block IDs, paired by position.
    store_gpu_blocks: list[int] = field(default_factory=list)
    store_cpu_blocks: list[int] = field(default_factory=list)

    # Whether any requests were preempted this step and need flush pending transfers.
    need_flush: bool = False
|
||||
|
||||
|
||||
@dataclass
class SimpleCPUOffloadWorkerMetadata(KVConnectorWorkerMetadata):
    """Worker -> Scheduler metadata for completed store events.

    Each worker reports {event_idx: 1} for newly completed stores.
    ``aggregate()`` sums counts across workers within a step.
    The scheduler-side manager accumulates across steps and processes
    a store completion only when count reaches ``world_size``.
    """

    completed_store_events: dict[int, int]

    def aggregate(
        self, other: "KVConnectorWorkerMetadata"
    ) -> "KVConnectorWorkerMetadata":
        assert isinstance(other, SimpleCPUOffloadWorkerMetadata)
        # Sum per-event counts over the union of both report maps.
        mine = self.completed_store_events
        theirs = other.completed_store_events
        combined = {
            idx: mine.get(idx, 0) + theirs.get(idx, 0)
            for idx in mine.keys() | theirs.keys()
        }
        return SimpleCPUOffloadWorkerMetadata(completed_store_events=combined)
|
||||
305
vllm/v1/simple_kv_offload/worker.py
Normal file
305
vllm/v1/simple_kv_offload/worker.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Worker-side handler for SimpleCPUOffloadConnector."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.simple_kv_offload.copy_backend import DmaCopyBackend
|
||||
from vllm.v1.simple_kv_offload.cuda_mem_ops import pin_tensor
|
||||
from vllm.v1.simple_kv_offload.metadata import (
|
||||
SimpleCPUOffloadMetadata,
|
||||
SimpleCPUOffloadWorkerMetadata,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SimpleCPUOffloadWorker:
|
||||
"""Worker-side handler for CPU offloading transfers."""
|
||||
|
||||
def __init__(
    self,
    vllm_config: VllmConfig,
    kv_cache_config: "KVCacheConfig | None",
    cpu_capacity_bytes: int,
):
    """Initialize worker-side offload state (no allocation happens here).

    Args:
        vllm_config: Global vLLM configuration.
        kv_cache_config: GPU KV cache layout; its num_blocks is used later
            in register_kv_caches() to size the CPU mirrors.
        cpu_capacity_bytes: Byte budget for the CPU-side KV cache.
    """
    self.vllm_config = vllm_config
    self.kv_cache_config = kv_cache_config
    self.cpu_capacity_bytes = cpu_capacity_bytes

    # Populated in register_kv_caches().
    self.gpu_kv_caches: dict[str, torch.Tensor] | None = None
    self.cpu_kv_caches: dict[str, torch.Tensor] | None = None
    self.device: torch.device | None = None
    self.num_cpu_blocks: int = 0

    # CUDA streams for the async transfers
    self.load_stream: torch.cuda.Stream | None = None
    self.store_stream: torch.cuda.Stream | None = None

    self._backend = DmaCopyBackend()

    # Ordered (event_idx, Event). Events pre-allocated on main thread.
    self._load_events: list[tuple[int, torch.Event]] = []
    self._store_events: list[tuple[int, torch.Event]] = []
    # High-water marks: highest event_idx completed per stream.
    # When the event list is empty, the hwm covers all prior events.
    self._load_hwm: int = -1
    self._store_hwm: int = -1

    # Metadata for the current step
    self._connector_metadata: SimpleCPUOffloadMetadata | None = None

    # Pending event index sets, populated in bind_connector_metadata
    self._pending_load_event_indices: set[int] = set()
    self._pending_store_event_indices: set[int] = set()
    # Completed store events to report via build_connector_worker_meta
    self._completed_store_events: dict[int, int] = {}
|
||||
|
||||
def register_kv_caches(
    self,
    kv_caches: dict[str, torch.Tensor],
) -> None:
    """Register GPU KV caches and allocate pinned CPU tensors.
    The worker will infer the underlying raw storage from the kv_caches.

    Args:
        kv_caches: Per-layer GPU KV caches. Values are either a single
            tensor (attention layers) or a list of tensors (Mamba layers
            in hybrid models). All values are included for offloading
            by resolving to their underlying raw storage.
    """
    if not kv_caches:
        logger.warning("No KV caches to offload.")
        return

    # Resolve each entry to a representative tensor for storage
    # deduplication. For attention layers the value is already a tensor;
    # for Mamba layers it is a list of tensors that all share the same
    # underlying raw storage, so we take the first one.
    def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
        assert isinstance(v, torch.Tensor | list)
        return v if isinstance(v, torch.Tensor) else v[0]

    any_tensor = _repr_tensor(next(iter(kv_caches.values())))
    self.device = any_tensor.device

    assert self.kv_cache_config is not None
    num_blocks = self.kv_cache_config.num_blocks

    # Deduplicate: multiple layers may share the same backing storage.
    seen_ptrs: dict[int, tuple[str, torch.Tensor]] = {}
    for name, value in kv_caches.items():
        tensor = _repr_tensor(value)
        ptr = tensor.untyped_storage().data_ptr()
        if ptr not in seen_ptrs:
            seen_ptrs[ptr] = (name, tensor)

    # Build [num_blocks, block_bytes] int8 views from each unique
    # storage so that stride(0) gives block_bytes for the copy op.
    #
    # The physical layout varies across attention backends:
    #   FlashAttn/ROCm: (2, num_blocks, ...) -> K/V outermost, 2 segments
    #   FlashInfer/MLA: (num_blocks, ...)    -> blocks outermost, 1 segment
    # We derive page_size_bytes = storage.nbytes() // num_blocks, then
    # classify dims: any dim whose byte-stride exceeds page_size_bytes
    # must be an outer segment dim (e.g. the K/V dim of size 2). A less
    # hacky way is to update the interface with the layout.
    unique_gpu_caches: dict[str, torch.Tensor] = {}
    for name, tensor in seen_ptrs.values():
        storage = tensor.untyped_storage()
        # Flat int8 alias over the entire backing storage.
        raw = torch.empty(0, dtype=torch.int8, device=self.device).set_(
            storage, 0, (storage.nbytes(),)
        )
        el = tensor.element_size()
        page_size_bytes = storage.nbytes() // num_blocks
        outer_dims = [
            d for d in range(tensor.ndim) if tensor.stride(d) * el > page_size_bytes
        ]
        if not outer_dims:
            unique_gpu_caches[name] = raw.view(num_blocks, -1)
        else:
            # Slice each outer segment (e.g. K and V) into its own view.
            seg_stride = tensor.stride(outer_dims[0]) * el
            for idx in range(tensor.shape[outer_dims[0]]):
                offset = idx * seg_stride
                chunk = raw[offset : offset + seg_stride]
                unique_gpu_caches[f"{name}.{idx}"] = chunk.view(num_blocks, -1)

    # Compute per-tensor bytes_per_block. Tensors may have different
    # page_size_bytes (e.g., UniformTypeKVCacheSpecs with varying head_size).
    per_tensor_bpb = [
        t.stride(0) * t.element_size() for t in unique_gpu_caches.values()
    ]
    total_bytes_per_block = sum(per_tensor_bpb)

    self.num_cpu_blocks = max(1, self.cpu_capacity_bytes // total_bytes_per_block)

    logger.info(
        "SimpleCPUOffloadWorker: %d unique GPU KV tensors, "
        "allocating %d CPU blocks (%.2f GB)",
        len(unique_gpu_caches),
        self.num_cpu_blocks,
        (self.num_cpu_blocks * total_bytes_per_block) / (1024**3),
    )

    pin_memory = is_pin_memory_available()
    if not pin_memory:
        logger.warning(
            "Pinned memory not available. CPU offload performance may be degraded."
        )

    self.gpu_kv_caches = unique_gpu_caches
    self.cpu_kv_caches = {}
    for name, gpu_tensor in unique_gpu_caches.items():
        cpu_shape = (self.num_cpu_blocks,) + gpu_tensor.shape[1:]
        # Allocate non-pinned first, then pin via cudaHostRegister to
        # bypass PyTorch's CUDACachingHostAllocator which rounds up to
        # the next power of 2 (e.g. 100 GB -> 128 GB).
        tensor = torch.zeros(cpu_shape, dtype=gpu_tensor.dtype, device="cpu")
        if pin_memory:
            pin_tensor(tensor)
        self.cpu_kv_caches[name] = tensor

    # Use lowest priority so KV cache I/O yields to compute streams.
    low_pri, _ = torch.cuda.Stream.priority_range()
    self.load_stream = torch.cuda.Stream(priority=low_pri)
    self.store_stream = torch.cuda.Stream(priority=low_pri)

    # Initialize copy backend with caches and streams.
    self._backend.init(
        self.gpu_kv_caches,
        self.cpu_kv_caches,
        self.device,
        self.load_stream,
        self.store_stream,
    )
|
||||
|
||||
def bind_connector_metadata(self, metadata: SimpleCPUOffloadMetadata) -> None:
    """Stash this step's metadata and mark its transfer events as pending."""
    self._connector_metadata = metadata
    # A negative event index (INVALID_JOB_ID) means "nothing this step".
    for event_idx, pending in (
        (metadata.load_event, self._pending_load_event_indices),
        (metadata.store_event, self._pending_store_event_indices),
    ):
        if event_idx >= 0:
            pending.add(event_idx)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
    """Drop the per-step metadata once the step has been handled."""
    self._connector_metadata = None
|
||||
|
||||
def start_load_kv(self) -> None:
    """Deliberate no-op: transfers are launched in get_finished() instead."""
    # NOTE: we defer launching both load and store to get_finished(),
    # which runs after model execution. This hides the CPU-side
    # block copy op overhead (~5ms) behind GPU compute.
    pass
|
||||
|
||||
def wait_for_save(self) -> None:
    """No-op: store completion is tracked via events polled in get_finished()."""
    pass
|
||||
|
||||
def get_finished(
|
||||
self,
|
||||
finished_req_ids: set[str],
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""Submit transfers and report completed events to the scheduler.
|
||||
|
||||
Called after model execution. The manager only schedules stores for
|
||||
blocks whose KV data is confirmed computed, so we launch both loads
|
||||
and stores immediately — no deferral or cross-stream sync needed.
|
||||
|
||||
Returns:
|
||||
tuple of (finished_sending, finished_recving).
|
||||
- finished_sending: always None (stores use worker metadata).
|
||||
- finished_recving: req_ids whose loads have completed.
|
||||
"""
|
||||
# (1) Submit transfers
|
||||
metadata = self._connector_metadata
|
||||
if metadata is not None:
|
||||
# Launch loads (CPU->GPU).
|
||||
if metadata.load_cpu_blocks:
|
||||
self._backend.launch_copy(
|
||||
metadata.load_cpu_blocks,
|
||||
metadata.load_gpu_blocks,
|
||||
is_store=False,
|
||||
event_idx=metadata.load_event,
|
||||
events_list=self._load_events,
|
||||
)
|
||||
# Launch stores (GPU->CPU).
|
||||
if metadata.store_gpu_blocks:
|
||||
self._backend.launch_copy(
|
||||
metadata.store_gpu_blocks,
|
||||
metadata.store_cpu_blocks,
|
||||
is_store=True,
|
||||
event_idx=metadata.store_event,
|
||||
events_list=self._store_events,
|
||||
)
|
||||
|
||||
# (2) Track completed transfer events
|
||||
finished_recving: set[str] = set()
|
||||
|
||||
if self._pending_load_event_indices:
|
||||
load_wm = self._poll_stream_events(is_store=False)
|
||||
for j in [j for j in self._pending_load_event_indices if j <= load_wm]:
|
||||
self._pending_load_event_indices.discard(j)
|
||||
req_ids = (
|
||||
metadata.load_event_to_reqs.get(j) if metadata is not None else None
|
||||
)
|
||||
if req_ids:
|
||||
finished_recving.update(req_ids)
|
||||
|
||||
if self._pending_store_event_indices:
|
||||
store_wm = self._poll_stream_events(is_store=True)
|
||||
for j in [j for j in self._pending_store_event_indices if j <= store_wm]:
|
||||
self._pending_store_event_indices.discard(j)
|
||||
self._completed_store_events[j] = 1
|
||||
|
||||
return None, finished_recving or None
|
||||
|
||||
def build_connector_worker_meta(self) -> SimpleCPUOffloadWorkerMetadata | None:
|
||||
"""Return completed store events since the last call."""
|
||||
if not self._completed_store_events:
|
||||
return None
|
||||
meta = SimpleCPUOffloadWorkerMetadata(
|
||||
completed_store_events=self._completed_store_events,
|
||||
)
|
||||
self._completed_store_events = {}
|
||||
return meta
|
||||
|
||||
def handle_preemptions(
|
||||
self, kv_connector_metadata: SimpleCPUOffloadMetadata
|
||||
) -> None:
|
||||
"""Sync all in-flight transfers before preempted blocks are reused."""
|
||||
if not kv_connector_metadata.need_flush:
|
||||
return
|
||||
self._flush_and_sync_all()
|
||||
|
||||
def _flush_and_sync_all(self) -> None:
|
||||
"""Synchronize all in-flight transfer events."""
|
||||
for event_idx, event in self._load_events:
|
||||
event.synchronize()
|
||||
self._load_hwm = event_idx
|
||||
self._load_events.clear()
|
||||
|
||||
for event_idx, event in self._store_events:
|
||||
event.synchronize()
|
||||
self._store_hwm = event_idx
|
||||
self._store_events.clear()
|
||||
|
||||
def _poll_stream_events(self, is_store: bool) -> int:
|
||||
"""Non-blocking poll for completed events and return the high-water mark."""
|
||||
events = self._store_events if is_store else self._load_events
|
||||
hwm = self._store_hwm if is_store else self._load_hwm
|
||||
while events:
|
||||
event_idx, event = events[0]
|
||||
if not event.query():
|
||||
break
|
||||
hwm = event_idx
|
||||
events.pop(0)
|
||||
if is_store:
|
||||
self._store_hwm = hwm
|
||||
else:
|
||||
self._load_hwm = hwm
|
||||
return hwm
|
||||
@@ -818,7 +818,6 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
@@ -833,7 +832,7 @@ class SpecDecodeBaseProposer:
|
||||
"""
|
||||
# Precompute get_token_id for when there is no valid next token
|
||||
num_reqs = gpu_input_batch.num_reqs
|
||||
seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
|
||||
seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
|
||||
self.backup_next_token_ids.np[:num_reqs] = np.array(
|
||||
[
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
|
||||
|
||||
@@ -286,7 +286,6 @@ class ExtractHiddenStatesProposer:
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
@@ -303,7 +302,7 @@ class ExtractHiddenStatesProposer:
|
||||
device = sampled_token_ids.device
|
||||
|
||||
# Compute backup tokens for discarded / invalid requests
|
||||
seq_lens_list = seq_lens[:num_reqs].tolist()
|
||||
seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
|
||||
backup_tokens_gpu = torch.tensor(
|
||||
[
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
|
||||
|
||||
@@ -108,6 +108,15 @@ class CPUWorker(Worker):
|
||||
if ret:
|
||||
logger.info(ret)
|
||||
|
||||
# After the thread binding, changing thread num is not allowed
|
||||
def skip_set_num_threads(x: int):
|
||||
logger.warning(
|
||||
"CPU backend doesn't allow to use "
|
||||
"`torch.set_num_threads` after the thread binding, skip it."
|
||||
)
|
||||
|
||||
torch.set_num_threads = skip_set_num_threads
|
||||
|
||||
# Note: unique identifier for creating allreduce shared memory
|
||||
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1]
|
||||
# Initialize the distributed environment.
|
||||
|
||||
@@ -208,7 +208,7 @@ from .utils import (
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1933,9 +1933,24 @@ class GPUModelRunner(
|
||||
# _update_states_after_model_execute for hybrid models).
|
||||
if self.num_accepted_tokens_event is not None:
|
||||
self.num_accepted_tokens_event.synchronize()
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
|
||||
)
|
||||
# Async mode: condense() reordered indices, use prev_positions mapping
|
||||
if self.use_async_scheduling and prev_req_id_to_index:
|
||||
prev_idx = self.prev_positions.np[:num_reqs]
|
||||
new_mask = prev_idx < 0
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[
|
||||
np.where(new_mask, 0, prev_idx)
|
||||
]
|
||||
)
|
||||
self.num_accepted_tokens.np[:num_reqs][new_mask] = 1
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs] = (
|
||||
self.num_accepted_tokens.np[:num_reqs]
|
||||
)
|
||||
else:
|
||||
# Non-async mode: use values directly
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
|
||||
)
|
||||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||||
self.num_accepted_tokens.copy_to_gpu()
|
||||
else:
|
||||
@@ -4211,7 +4226,6 @@ class GPUModelRunner(
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
self.optimistic_seq_lens_cpu,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
@@ -4578,7 +4592,6 @@ class GPUModelRunner(
|
||||
)
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
self.optimistic_seq_lens_cpu,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
@@ -4617,7 +4630,6 @@ class GPUModelRunner(
|
||||
)
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
self.optimistic_seq_lens_cpu,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
@@ -5969,7 +5981,9 @@ class GPUModelRunner(
|
||||
SupportsEncoderCudaGraph,
|
||||
supports_encoder_cudagraph,
|
||||
)
|
||||
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
|
||||
EncoderCudaGraphManager,
|
||||
)
|
||||
|
||||
raw_model = self.get_model()
|
||||
if supports_encoder_cudagraph(raw_model):
|
||||
|
||||
Reference in New Issue
Block a user