[BugFix][kv_offload] Fix offloading decodes with async scheduling (#33881)
Signed-off-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -148,17 +148,23 @@ class TransferSummary:
|
||||
|
||||
class RequestRunner:
|
||||
def __init__(
|
||||
self, offloaded_block_size: int, gpu_block_size: int, num_gpu_blocks: int
|
||||
self,
|
||||
offloaded_block_size: int,
|
||||
gpu_block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
async_scheduling: bool = True,
|
||||
):
|
||||
self.offloaded_block_size: int = offloaded_block_size
|
||||
self.gpu_block_size: int = gpu_block_size
|
||||
self.num_gpu_blocks: int = num_gpu_blocks
|
||||
self.async_scheduling: bool = async_scheduling
|
||||
|
||||
self.req_id: int = -1
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
block_size=gpu_block_size, max_num_batched_tokens=1000
|
||||
)
|
||||
vllm_config.scheduler_config.async_scheduling = async_scheduling
|
||||
vllm_config.kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="OffloadingConnector",
|
||||
kv_role="kv_both",
|
||||
@@ -313,6 +319,8 @@ class RequestRunner:
|
||||
|
||||
tokens_iter = iter(decoded_tokens)
|
||||
token_id = next(tokens_iter, None)
|
||||
prev_scheduler_output = None
|
||||
prev_model_runner_output = None
|
||||
while True:
|
||||
assert self.scheduler.requests
|
||||
|
||||
@@ -354,7 +362,16 @@ class RequestRunner:
|
||||
if self.scheduler.running:
|
||||
token_id = next(tokens_iter, None)
|
||||
|
||||
self.scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
if self.async_scheduling:
|
||||
# in async scheduling we update the output of the previous step
|
||||
if prev_model_runner_output is not None:
|
||||
self.scheduler.update_from_output(
|
||||
prev_scheduler_output, prev_model_runner_output
|
||||
)
|
||||
prev_scheduler_output = scheduler_output
|
||||
prev_model_runner_output = model_runner_output
|
||||
else:
|
||||
self.scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
if (
|
||||
prev_token_id == EOS_TOKEN_ID
|
||||
@@ -365,6 +382,11 @@ class RequestRunner:
|
||||
continue
|
||||
|
||||
if token_id is None:
|
||||
if self.async_scheduling:
|
||||
# sample last token
|
||||
self.scheduler.update_from_output(
|
||||
prev_scheduler_output, prev_model_runner_output
|
||||
)
|
||||
break
|
||||
|
||||
self._parse_transfers()
|
||||
@@ -445,11 +467,14 @@ class RequestRunner:
|
||||
def request_runner():
|
||||
runners = []
|
||||
|
||||
def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
|
||||
def runner_factory(
|
||||
offloaded_block_size, gpu_block_size, num_gpu_blocks, async_scheduling
|
||||
):
|
||||
runner = RequestRunner(
|
||||
offloaded_block_size=offloaded_block_size,
|
||||
gpu_block_size=gpu_block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
async_scheduling=async_scheduling,
|
||||
)
|
||||
runners.append(runner)
|
||||
return runner
|
||||
@@ -466,7 +491,8 @@ def generate_store_output(block_hashes: Iterable[BlockHash]):
|
||||
)
|
||||
|
||||
|
||||
def test_offloading_connector(request_runner):
|
||||
@pytest.mark.parametrize("async_scheduling", [True, False])
|
||||
def test_offloading_connector(request_runner, async_scheduling: bool):
|
||||
offloaded_block_size = 12
|
||||
gpu_block_size = 4
|
||||
num_gpu_blocks = 100
|
||||
@@ -476,6 +502,7 @@ def test_offloading_connector(request_runner):
|
||||
offloaded_block_size=offloaded_block_size,
|
||||
gpu_block_size=gpu_block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
async_scheduling=async_scheduling,
|
||||
)
|
||||
|
||||
# 3 blocks, store just the middle block (skip first and last)
|
||||
@@ -498,26 +525,28 @@ def test_offloading_connector(request_runner):
|
||||
runner.run(decoded_tokens=[0])
|
||||
runner.manager.prepare_store.assert_called()
|
||||
|
||||
# 1 more block, now set block_hashes_to_store = []
|
||||
# 1 more block (+ token for async scheduling)
|
||||
# now set block_hashes_to_store = []
|
||||
runner.manager.prepare_store.side_effect = (
|
||||
lambda block_hashes: generate_store_output([])
|
||||
)
|
||||
runner.run(decoded_tokens=[0] * offloaded_block_size)
|
||||
runner.run(decoded_tokens=[0] * (offloaded_block_size + 1))
|
||||
|
||||
# 1 more block, now check touch was called with all 6 blocks
|
||||
# 1 more block (+ token for kicking off offloading)
|
||||
# now check touch was called with all 6 blocks
|
||||
runner.manager.prepare_store.side_effect = (
|
||||
lambda block_hashes: generate_store_output(block_hashes)
|
||||
)
|
||||
runner.run(decoded_tokens=[0] * offloaded_block_size)
|
||||
runner.run(
|
||||
decoded_tokens=[0] * (offloaded_block_size + 1),
|
||||
expected_stored_gpu_block_indexes=(15, 16, 17),
|
||||
)
|
||||
runner.manager.touch.assert_called()
|
||||
block_hashes1 = list(runner.manager.touch.call_args.args[0])
|
||||
assert len(block_hashes1) == 6
|
||||
|
||||
# terminate request
|
||||
runner.run(
|
||||
decoded_tokens=[EOS_TOKEN_ID],
|
||||
expected_stored_gpu_block_indexes=(15, 16, 17),
|
||||
)
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID])
|
||||
|
||||
# create a new request differing only on the last token
|
||||
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
|
||||
@@ -608,7 +637,8 @@ def test_offloading_connector(request_runner):
|
||||
assert event.medium == "B"
|
||||
|
||||
|
||||
def test_request_preemption(request_runner):
|
||||
@pytest.mark.parametrize("async_scheduling", [True, False])
|
||||
def test_request_preemption(request_runner, async_scheduling: bool):
|
||||
offloaded_block_size = 12
|
||||
gpu_block_size = 4
|
||||
num_gpu_blocks = 100
|
||||
@@ -617,6 +647,7 @@ def test_request_preemption(request_runner):
|
||||
offloaded_block_size=offloaded_block_size,
|
||||
gpu_block_size=gpu_block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
async_scheduling=async_scheduling,
|
||||
)
|
||||
|
||||
free_block_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
|
||||
@@ -674,7 +705,8 @@ def test_request_preemption(request_runner):
|
||||
)
|
||||
|
||||
|
||||
def test_concurrent_lookups_of_the_same_prefix(request_runner):
|
||||
@pytest.mark.parametrize("async_scheduling", [True, False])
|
||||
def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: bool):
|
||||
offloaded_block_size = 12
|
||||
gpu_block_size = 4
|
||||
num_gpu_blocks = 100
|
||||
@@ -683,6 +715,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner):
|
||||
offloaded_block_size=offloaded_block_size,
|
||||
gpu_block_size=gpu_block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
async_scheduling=async_scheduling,
|
||||
)
|
||||
|
||||
# store 1 blocks
|
||||
@@ -732,7 +765,8 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner):
|
||||
assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs)
|
||||
|
||||
|
||||
def test_abort_loading_requests(request_runner):
|
||||
@pytest.mark.parametrize("async_scheduling", [True, False])
|
||||
def test_abort_loading_requests(request_runner, async_scheduling: bool):
|
||||
offloaded_block_size = 12
|
||||
gpu_block_size = 4
|
||||
num_gpu_blocks = 100
|
||||
@@ -741,6 +775,7 @@ def test_abort_loading_requests(request_runner):
|
||||
offloaded_block_size=offloaded_block_size,
|
||||
gpu_block_size=gpu_block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
async_scheduling=async_scheduling,
|
||||
)
|
||||
|
||||
# store 1 blocks
|
||||
|
||||
@@ -31,6 +31,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( #
|
||||
from vllm.utils.hashing import sha256
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
@@ -143,7 +144,7 @@ def create_scheduler(
|
||||
vllm_config: VllmConfig,
|
||||
num_blocks: int = 10000,
|
||||
kv_cache_config: KVCacheConfig | None = None,
|
||||
) -> Scheduler:
|
||||
) -> Scheduler | AsyncScheduler:
|
||||
"""Initialize Scheduler For Testing."""
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
if kv_cache_config is None:
|
||||
@@ -163,7 +164,11 @@ def create_scheduler(
|
||||
],
|
||||
)
|
||||
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
||||
return Scheduler(
|
||||
|
||||
scheduler_cls = (
|
||||
AsyncScheduler if vllm_config.scheduler_config.async_scheduling else Scheduler
|
||||
)
|
||||
return scheduler_cls(
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
log_stats=True,
|
||||
|
||||
@@ -416,7 +416,9 @@ class OffloadingConnectorScheduler:
|
||||
|
||||
req = self._requests[req_id]
|
||||
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
total_tokens = req.num_computed_tokens + new_tokens
|
||||
expected_tokens = req.num_computed_tokens + new_tokens
|
||||
# with async scheduling, some tokens may be missing
|
||||
total_tokens = min(expected_tokens, req.num_tokens)
|
||||
num_blocks = total_tokens // self.offloaded_block_size
|
||||
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
|
||||
num_new_blocks = num_blocks - start_block_idx
|
||||
@@ -424,8 +426,8 @@ class OffloadingConnectorScheduler:
|
||||
if num_new_blocks <= 0:
|
||||
continue
|
||||
|
||||
# NOTE: In async scheduling, placeholders may temporarily make
|
||||
# len(req.block_hashes) < num_blocks * self.block_size_factor.
|
||||
num_gpu_blocks = num_blocks * self.block_size_factor
|
||||
assert len(req.block_hashes) >= num_gpu_blocks
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||
@@ -529,6 +531,9 @@ class OffloadingConnectorScheduler:
|
||||
req_id = request.request_id
|
||||
self._requests.pop(req_id, None)
|
||||
self._request_block_ids.pop(req_id, None)
|
||||
|
||||
# TODO(orozery): possibly kickoff offload for last block
|
||||
# which may have been deferred due to async scheduling
|
||||
self._next_stored_block_idx.pop(req_id, None)
|
||||
|
||||
request_being_stored = req_id in self._reqs_being_stored
|
||||
|
||||
Reference in New Issue
Block a user