Remove hardcoded device="cuda" to support more devices (#2503)

Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Author: Kunshang Ji
Date: 2024-02-02 07:46:39 +08:00
Committed by: GitHub
Commit: 96b6f475dd (parent c410f5d020)
32 changed files with 343 additions and 292 deletions
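
The change threads a new DeviceConfig from vllm.config through the engine and on to its workers, so the target device is read from configuration instead of a hardcoded device="cuda" string. The excerpt below covers the LLMEngine side of the diff. As a minimal sketch of the idea: a device config only has to carry a resolved torch.device. The .device attribute is the one this diff actually uses; the rest of this sketch is an assumption about the real vllm.config.DeviceConfig:

```python
import torch


class DeviceConfig:
    """Sketch: carries the torch.device the engine should target."""

    def __init__(self, device: str = "cuda") -> None:
        # "cuda" stays the default, but callers can now pass "cpu"
        # or any other device string that torch understands.
        self.device = torch.device(device)
```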

vllm/engine/llm_engine.py

@@ -6,8 +6,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 from vllm.lora.request import LoRARequest
-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, LoRAConfig)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics import StatLogger, Stats
@@ -53,6 +53,7 @@ class LLMEngine:
             management.
         parallel_config: The configuration related to distributed execution.
         scheduler_config: The configuration related to the request scheduler.
+        device_config: The configuration related to the device.
         placement_group: Ray placement group for distributed execution.
             Required for distributed execution.
         log_stats: Whether to log statistics.
@@ -64,6 +65,7 @@ class LLMEngine:
         cache_config: CacheConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         placement_group: Optional["PlacementGroup"],
         log_stats: bool,
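
Since device_config is a new positional parameter of LLMEngine.__init__, direct constructions need updating; most callers go through EngineArgs instead. A hedged usage sketch, assuming the companion arg_utils change in this commit (not part of this excerpt) gives EngineArgs a matching device field:

```python
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

# Assumption: EngineArgs gained a `device` field in this PR; the
# model name here is only an example.
engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", device="cuda"))
```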
@@ -85,6 +87,7 @@ class LLMEngine:
             f"quantization={model_config.quantization}, "
             f"enforce_eager={model_config.enforce_eager}, "
             f"kv_cache_dtype={cache_config.cache_dtype}, "
+            f"device_config={device_config.device}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
@@ -93,6 +96,7 @@ class LLMEngine:
         self.lora_config = lora_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.device_config = device_config
         self.log_stats = log_stats

         self._verify_args()
@@ -138,6 +142,7 @@ class LLMEngine:
             self.model_config,
             self.parallel_config,
             self.scheduler_config,
+            self.device_config,
             local_rank=0,
             rank=0,
             distributed_init_method=distributed_init_method,
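
The worker being constructed here receives the config as a new positional argument. A sketch of the receiving side, trimmed to the arguments visible in this hunk (the real Worker class takes more parameters, and its module is not part of this excerpt):

```python
class Worker:
    """Sketch of the worker side; the real signature is longer."""

    def __init__(self, model_config, parallel_config, scheduler_config,
                 device_config, local_rank: int, rank: int,
                 distributed_init_method: str) -> None:
        self.device_config = device_config
        # Downstream allocations use this instead of a literal "cuda".
        self.device = device_config.device
```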
@@ -233,6 +238,7 @@ class LLMEngine:
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
+        device_config = copy.deepcopy(self.device_config)

         for rank, (worker, (node_id,
                             _)) in enumerate(zip(self.workers,
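
Mirroring the existing pattern for the other configs, device_config is deep-copied before being shipped to each Ray worker, so remote workers do not share mutable config state with the driver.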
@@ -244,6 +250,7 @@ class LLMEngine:
                 model_config,
                 parallel_config,
                 scheduler_config,
+                device_config,
                 local_rank,
                 rank,
                 distributed_init_method,
@@ -257,6 +264,7 @@ class LLMEngine:
             model_config,
             parallel_config,
             scheduler_config,
+            device_config,
             driver_local_rank,
             driver_rank,
             distributed_init_method,
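
The payoff is at allocation sites elsewhere in the 32 touched files: code that previously pinned tensors with device="cuda" can consult the config instead. A hypothetical helper illustrating the pattern (not part of this diff):

```python
import torch


def empty_on_configured_device(shape, device_config):
    # Before this change, sites like this hardcoded device="cuda";
    # now the device comes from DeviceConfig, e.g.
    # empty_on_configured_device((16, 4096), engine.device_config).
    return torch.empty(shape, device=device_config.device)
```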