[LMCache] Pass TP size in lookup for MLA multi-reader locking (#36129)

Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Co-authored-by: Yihua Cheng <yihua98@uchicago.edu>
This commit is contained in:
maobaolong
2026-03-12 04:45:20 +08:00
committed by GitHub
parent 7ee5d5093b
commit 12001f2ebc
2 changed files with 22 additions and 0 deletions

View File

@@ -114,6 +114,7 @@ class LMCacheMPSchedulerAdapter:
world_size: int,
kv_rank: int,
vllm_block_size: int,
tp_size: int = 1,
):
"""
Args:
@@ -124,6 +125,8 @@ class LMCacheMPSchedulerAdapter:
world_size: The world size used for LMCache keys
kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM
tp_size: Tensor-parallel size for MLA
multi-reader locking (default 1).
"""
self.mq_client = MessageQueueClient(server_url, context)
@@ -133,6 +136,7 @@ class LMCacheMPSchedulerAdapter:
self.model_name = model_name
self.world_size = world_size
self.worker_id = kv_rank
self.tp_size = tp_size
# Read chunk size from lmcache
self.chunk_size = get_lmcache_chunk_size(self.mq_client)
@@ -281,6 +285,7 @@ class LMCacheMPSchedulerAdapter:
start=start,
end=end,
request_id=request_id,
tp_size=self.tp_size,
)
def _create_hash_key(
@@ -293,6 +298,7 @@ class LMCacheMPSchedulerAdapter:
worker_id=None,
chunk_hash=chunk_hash,
request_id=request_id,
tp_size=self.tp_size,
)

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import inspect
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
@@ -52,6 +53,12 @@ if TYPE_CHECKING:
logger = lmcache_init_logger(__name__)
def _adapter_accepts_tp_size() -> bool:
    """Return True when LMCacheMPSchedulerAdapter.__init__ declares a tp_size parameter.

    Used to keep a newer vLLM compatible with an older LMCache: the caller
    only forwards ``tp_size`` when the installed adapter actually accepts it.
    """
    init_sig = inspect.signature(LMCacheMPSchedulerAdapter.__init__)
    return any(name == "tp_size" for name in init_sig.parameters)
# Helper functions
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
if block_ids is None:
@@ -101,6 +108,14 @@ def create_scheduler_adapter(
vllm_config.parallel_config.rank,
vllm_config,
)
tp_size = vllm_config.parallel_config.tensor_parallel_size
# Pass tp_size only when the adapter accepts it so that
# a newer vllm can still work with an older LMCache.
kwargs: dict[str, Any] = {}
if _adapter_accepts_tp_size():
kwargs["tp_size"] = tp_size
return LMCacheMPSchedulerAdapter(
server_url,
zmq_context,
@@ -108,6 +123,7 @@ def create_scheduler_adapter(
world_size,
kv_rank,
vllm_config.cache_config.block_size,
**kwargs,
)