[Refactor] Remove dead code in kv connector and model runner (#38383)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( #
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
ensure_kv_transfer_initialized,
|
||||
ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
)
|
||||
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
||||
@@ -57,4 +58,4 @@ def test_kv_connector_mixin_clears_metadata():
|
||||
assert connector.call_record.get("clear_connector_metadata", 0) == 1
|
||||
finally:
|
||||
# Ensure we clean up the global connector between tests
|
||||
KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown()
|
||||
ensure_kv_transfer_shutdown()
|
||||
|
||||
@@ -102,10 +102,6 @@ class CPUModelRunner(GPUModelRunner):
|
||||
# so stale KV cache data never affects computation.
|
||||
pass
|
||||
|
||||
def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]:
|
||||
# Note: For CPU backend, dp padding is not required for now.
|
||||
return 0, None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
|
||||
@@ -34,14 +34,6 @@ class ECConnectorModelRunnerMixin:
|
||||
connector = get_ec_transfer()
|
||||
connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
|
||||
|
||||
@staticmethod
|
||||
def get_finished_ec_transfers(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
if has_ec_transfer():
|
||||
return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def maybe_get_ec_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
|
||||
@@ -97,11 +97,6 @@ class ActiveKVConnector(KVConnector):
|
||||
self.kv_connector.clear_connector_metadata()
|
||||
return output
|
||||
|
||||
def clear_metadata(self) -> None:
|
||||
"""Clear the connector metadata. Call this after draft model runs."""
|
||||
if not self._disabled:
|
||||
self.kv_connector.clear_connector_metadata()
|
||||
|
||||
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||
if self._disabled:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
@@ -451,7 +451,6 @@ class GPUModelRunner(
|
||||
# Model-related.
|
||||
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
|
||||
self.inputs_embeds_size = model_config.get_inputs_embeds_size()
|
||||
self.attention_chunk_size = model_config.attention_chunk_size
|
||||
# Only relevant for models using ALiBi (e.g, MPT)
|
||||
self.use_alibi = model_config.uses_alibi
|
||||
|
||||
@@ -594,7 +593,6 @@ class GPUModelRunner(
|
||||
# NOTE(rob): num_prompt_logprobs only includes reqs
|
||||
# that are currently in the prefill phase.
|
||||
self.num_prompt_logprobs: dict[str, int] = {}
|
||||
self.comm_stream = torch.cuda.Stream()
|
||||
|
||||
# Input Batch
|
||||
# NOTE(Chen): Ideally, we should initialize the input batch inside
|
||||
|
||||
@@ -13,11 +13,7 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.kv_transfer import (
|
||||
ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
@@ -38,12 +34,6 @@ logger = init_logger(__name__)
|
||||
|
||||
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
|
||||
class KVConnectorModelRunnerMixin:
|
||||
@staticmethod
|
||||
def ensure_kv_transfer_shutdown() -> None:
|
||||
# has_kv_transfer_group can be None during interpreter shutdown.
|
||||
if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function]
|
||||
ensure_kv_transfer_shutdown()
|
||||
|
||||
@staticmethod
|
||||
def kv_connector_no_forward(
|
||||
scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
|
||||
|
||||
@@ -1,23 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import supports_xpu_graph
|
||||
from vllm.v1.worker.gpu.model_runner import (
|
||||
GPUModelRunner as GPUModelRunnerV2,
|
||||
)
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUModelRunner(GPUModelRunner):
|
||||
"""A model runner for XPU devices."""
|
||||
@@ -47,19 +40,16 @@ class XPUModelRunnerV2(GPUModelRunnerV2):
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
try:
|
||||
# replace cuda APIs with xpu APIs, this should work by default
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.default_stream = torch.xpu.current_stream
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.mem_get_info = torch.xpu.mem_get_info
|
||||
torch.cuda.Event = torch.Event
|
||||
torch.cuda.set_stream = torch.xpu.set_stream
|
||||
if supports_xpu_graph():
|
||||
torch.cuda.graph = torch.xpu.graph
|
||||
torch.cuda.CUDAGraph = torch.xpu.XPUGraph
|
||||
torch.cuda.graph_pool_handle = torch.xpu.graph_pool_handle
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
# replace cuda APIs with xpu APIs, this should work by default
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.default_stream = torch.xpu.current_stream
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.mem_get_info = torch.xpu.mem_get_info
|
||||
torch.cuda.Event = torch.Event
|
||||
torch.cuda.set_stream = torch.xpu.set_stream
|
||||
if supports_xpu_graph():
|
||||
torch.cuda.graph = torch.xpu.graph
|
||||
torch.cuda.CUDAGraph = torch.xpu.XPUGraph
|
||||
torch.cuda.graph_pool_handle = torch.xpu.graph_pool_handle
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user