[v1] torchrun compatibility (#13642)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@@ -185,9 +185,8 @@ class Worker(WorkerBase):
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
kv_cache_config = kv_cache_configs[self.rank]
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
@@ -225,7 +224,7 @@ class Worker(WorkerBase):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
output = self.model_runner.execute_model(scheduler_output)
|
||||
return output if self.rank == 0 else None
|
||||
return output if self.is_driver_worker else None
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
|
||||
Reference in New Issue
Block a user