From 35bdca5431e652b4c00267489a632c1bf5522103 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Wed, 11 Mar 2026 15:40:17 -0400 Subject: [PATCH] [Refactor] Remove dead code in KV connector (#36424) Signed-off-by: yewentao256 --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 8 +------- vllm/v1/core/sched/scheduler.py | 8 +++----- vllm/v1/engine/core.py | 4 +--- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index cc16dee82..e6c49d7a0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -50,7 +50,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group, ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger @@ -564,7 +563,6 @@ class NixlConnectorScheduler: # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: threading.Thread | None = None - self._encoded_xfer_handshake_metadata: dict[int, Any] = {} self._stop_event = threading.Event() # Requests that need to start recv/send. @@ -650,7 +648,6 @@ class NixlConnectorScheduler: tp_rank, str(len(encoded_data[tp_rank])), ) - self._encoded_xfer_handshake_metadata = encoded_data # Only start the listener when we have metadata to serve. if self._nixl_handshake_listener_t is None: @@ -995,7 +992,7 @@ class NixlConnectorWorker: self.engine_id: EngineId = engine_id self.tp_rank = get_tensor_model_parallel_rank() self.world_size = get_tensor_model_parallel_world_size() - self.tp_group = get_tp_group() + self.num_blocks = kv_cache_config.num_blocks self.enable_permute_local_kv = False @@ -1064,7 +1061,6 @@ class NixlConnectorWorker: # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 - self.num_layers = 0 # nixl_prepped_dlist_handle. self.src_xfer_handles_by_block_size: dict[int, int] = {} @@ -1108,7 +1104,6 @@ class NixlConnectorWorker: self.block_size = vllm_config.cache_config.block_size self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config self.use_mla = self.model_config.use_mla @@ -1540,7 +1535,6 @@ class NixlConnectorWorker: self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses self.num_regions = len(caches_data) - self.num_layers = len(xfer_buffers.keys()) descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 61418692b..ea2c2a6cd 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -184,13 +184,11 @@ class Scheduler(SchedulerInterface): # Encoder-related. # Calculate encoder cache size if applicable - self.supports_mm_inputs = mm_registry.supports_multimodal_inputs( + supports_mm_inputs = mm_registry.supports_multimodal_inputs( vllm_config.model_config ) - self.mm_budget = mm_budget = ( - MultiModalBudget(vllm_config, mm_registry) - if self.supports_mm_inputs - else None + mm_budget = ( + MultiModalBudget(vllm_config, mm_registry) if supports_mm_inputs else None ) # NOTE: Text-only encoder-decoder models are implemented as diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 50c116f85..3d315086f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -148,7 +148,7 @@ class EngineCore: if self.scheduler.connector is not None: # type: ignore self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore - self.mm_registry = mm_registry = MULTIMODAL_REGISTRY + mm_registry = MULTIMODAL_REGISTRY self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config( vllm_config ) @@ -800,8 +800,6 @@ class EngineCoreProc(EngineCore): vllm_config, client_handshake_address, ) as addresses: - self.client_count = len(addresses.outputs) - # Set up data parallel environment. self.has_coordinator = addresses.coordinator_output is not None self.frontend_stats_publish_address = (