[LMCache] Pass TP size in lookup for MLA multi-reader locking (#36129)
Signed-off-by: baoloongmao <baoloongmao@tencent.com> Co-authored-by: Yihua Cheng <yihua98@uchicago.edu>
This commit is contained in:
@@ -114,6 +114,7 @@ class LMCacheMPSchedulerAdapter:
|
||||
world_size: int,
|
||||
kv_rank: int,
|
||||
vllm_block_size: int,
|
||||
tp_size: int = 1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -124,6 +125,8 @@ class LMCacheMPSchedulerAdapter:
|
||||
world_size: The world size used for LMCache keys
|
||||
kv_rank: The kv rank used for LMCache keys
|
||||
vllm_block_size: The block size used in vLLM
|
||||
tp_size: Tensor-parallel size for MLA
|
||||
multi-reader locking (default 1).
|
||||
"""
|
||||
self.mq_client = MessageQueueClient(server_url, context)
|
||||
|
||||
@@ -133,6 +136,7 @@ class LMCacheMPSchedulerAdapter:
|
||||
self.model_name = model_name
|
||||
self.world_size = world_size
|
||||
self.worker_id = kv_rank
|
||||
self.tp_size = tp_size
|
||||
|
||||
# Read chunk size from lmcache
|
||||
self.chunk_size = get_lmcache_chunk_size(self.mq_client)
|
||||
@@ -281,6 +285,7 @@ class LMCacheMPSchedulerAdapter:
|
||||
start=start,
|
||||
end=end,
|
||||
request_id=request_id,
|
||||
tp_size=self.tp_size,
|
||||
)
|
||||
|
||||
def _create_hash_key(
|
||||
@@ -293,6 +298,7 @@ class LMCacheMPSchedulerAdapter:
|
||||
worker_id=None,
|
||||
chunk_hash=chunk_hash,
|
||||
request_id=request_id,
|
||||
tp_size=self.tp_size,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import enum
|
||||
import inspect
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
@@ -52,6 +53,12 @@ if TYPE_CHECKING:
|
||||
logger = lmcache_init_logger(__name__)
|
||||
|
||||
|
||||
def _adapter_accepts_tp_size() -> bool:
    """Report whether the imported adapter's constructor declares ``tp_size``.

    Inspects the signature of ``LMCacheMPSchedulerAdapter.__init__`` and looks
    for a parameter named exactly ``tp_size``.

    Returns:
        True if ``tp_size`` is one of the constructor's parameters,
        False otherwise.
    """
    init_params = inspect.signature(LMCacheMPSchedulerAdapter.__init__).parameters
    return "tp_size" in init_params
|
||||
|
||||
|
||||
# Helper functions
|
||||
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
|
||||
if block_ids is None:
|
||||
@@ -101,6 +108,14 @@ def create_scheduler_adapter(
|
||||
vllm_config.parallel_config.rank,
|
||||
vllm_config,
|
||||
)
|
||||
tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
# Pass tp_size only when the adapter accepts it so that
|
||||
# a newer vllm can still work with an older LMCache.
|
||||
kwargs: dict[str, Any] = {}
|
||||
if _adapter_accepts_tp_size():
|
||||
kwargs["tp_size"] = tp_size
|
||||
|
||||
return LMCacheMPSchedulerAdapter(
|
||||
server_url,
|
||||
zmq_context,
|
||||
@@ -108,6 +123,7 @@ def create_scheduler_adapter(
|
||||
world_size,
|
||||
kv_rank,
|
||||
vllm_config.cache_config.block_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user