diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py index db1d34ca1..38dd980c6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py @@ -336,11 +336,21 @@ class LMCacheMPRequestMetadata: start_token_idx = start * vllm_block_size end_token_idx = end * vllm_block_size token_ids = list(tracker.all_token_ids) + + # Compute how many tokens at the start of the retrieve range + # overlap with APC-shared blocks. The server must skip writing + # to these positions to avoid a cross-stream data race: the + # retrieve writes on the LMCache CUDA stream while concurrent + # requests may read these APC-shared blocks on the vLLM stream. + apc_overlap_blocks = tracker.num_vllm_hit_blocks - start + skip_first_n_tokens = apc_overlap_blocks * vllm_block_size + op = LoadStoreOp( token_ids=token_ids, block_ids=block_ids, start=start_token_idx, end=end_token_idx, + skip_first_n_tokens=skip_first_n_tokens, ) ret = LMCacheMPRequestMetadata( @@ -700,13 +710,22 @@ class LMCacheMPConnector(KVConnectorBase_V1): num_external_tokens (int): the number of tokens that will be loaded from the external KV cache. """ - # NOTE: the `blocks` are NEW BLOCKS allocated for this request. + # NOTE: `blocks` comes from kv_cache_manager.get_blocks(request_id), + # which returns ALL blocks for the request (not just newly allocated). + # This function may be called twice for async-load requests: + # 1st call: blocks = initial allocation (APC + fresh) + # 2nd call: blocks = all blocks + # (initial + newly allocated for remaining tokens) + # We must only append the NEW blocks beyond what's already tracked + # to avoid duplication, which would corrupt the store path's block indexing. tracker = self._get_request_tracker(request.request_id) block_ids = reformat_block_ids(blocks.get_block_ids()) - # No matter we need to retrieve or not, we need to update - # the block ids into the tracker - tracker.append_block_ids(block_ids) + # Only append blocks beyond what's already tracked + existing_count = len(tracker.allocated_block_ids) + new_block_ids = block_ids[existing_count:] + if new_block_ids: + tracker.append_block_ids(new_block_ids) # Update the state of the tracker condition = tracker.needs_retrieve()