diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 1f889c6c8..72980a85a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -457,9 +457,11 @@ class TpKVTopology: """ Whether the KV cache is replicated across TP workers due to the number of TP workers being greater than the number of KV heads. + When they are equal, each TP rank still owns one distinct KV head, + so this is not considered replication. """ tp_size = self.remote_tp_size[engine_id] - return tp_size // self.total_num_kv_heads >= 1 + return tp_size > self.total_num_kv_heads def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: # MLA is always replicated as the hidden dim can't be split.