Remove hardcoded device="cuda" to support more devices (#2503)
Co-authored-by: Jiang Li <jiang1.li@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -6,8 +6,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, LoRAConfig)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, LoRAConfig)
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.metrics import StatLogger, Stats
|
||||
@@ -53,6 +53,7 @@ class LLMEngine:
|
||||
management.
|
||||
parallel_config: The configuration related to distributed execution.
|
||||
scheduler_config: The configuration related to the request scheduler.
|
||||
device_config: The configuration related to the device.
|
||||
placement_group: Ray placement group for distributed execution.
|
||||
Required for distributed execution.
|
||||
log_stats: Whether to log statistics.
|
||||
@@ -64,6 +65,7 @@ class LLMEngine:
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
placement_group: Optional["PlacementGroup"],
|
||||
log_stats: bool,
|
||||
@@ -85,6 +87,7 @@ class LLMEngine:
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"enforce_eager={model_config.enforce_eager}, "
|
||||
f"kv_cache_dtype={cache_config.cache_dtype}, "
|
||||
f"device_config={device_config.device}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
@@ -93,6 +96,7 @@ class LLMEngine:
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.log_stats = log_stats
|
||||
self._verify_args()
|
||||
|
||||
@@ -138,6 +142,7 @@ class LLMEngine:
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
self.device_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
@@ -233,6 +238,7 @@ class LLMEngine:
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||
device_config = copy.deepcopy(self.device_config)
|
||||
|
||||
for rank, (worker, (node_id,
|
||||
_)) in enumerate(zip(self.workers,
|
||||
@@ -244,6 +250,7 @@ class LLMEngine:
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
device_config,
|
||||
local_rank,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
@@ -257,6 +264,7 @@ class LLMEngine:
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
device_config,
|
||||
driver_local_rank,
|
||||
driver_rank,
|
||||
distributed_init_method,
|
||||
|
||||
Reference in New Issue
Block a user