Remove hardcoded device="cuda" to support more devices (#2503)
Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
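The diff below removes the assumption that vLLM always runs on CUDA by threading a device string from the CLI through EngineArgs into a new DeviceConfig. As a quick orientation, here is a minimal usage sketch; the model name is a placeholder and the import path is taken from the surrounding vLLM codebase rather than from this diff:

    from vllm.engine.arg_utils import EngineArgs

    # After this change, the target device is an ordinary EngineArgs field
    # (default 'cuda') instead of a string hardcoded inside the engine.
    engine_args = EngineArgs(model="facebook/opt-125m", device="cuda")
    configs = engine_args.create_engine_configs()  # now also returns a DeviceConfig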
@@ -3,8 +3,8 @@ import dataclasses
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, LoRAConfig)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
 
 
 @dataclass
@@ -43,6 +43,7 @@ class EngineArgs:
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
+    device: str = 'cuda'
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -127,13 +128,13 @@ class EngineArgs:
             '--kv-cache-dtype',
             type=str,
             choices=['auto', 'fp8_e5m2'],
-            default='auto',
+            default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
             'data type. Note FP8 is not supported when cuda version is '
             'lower than 11.8.')
         parser.add_argument('--max-model-len',
                             type=int,
-                            default=None,
+                            default=EngineArgs.max_model_len,
                             help='model context length. If unspecified, '
                             'will be automatically derived from the model.')
         # Parallel arguments
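The default=EngineArgs.<field> changes above (and the matching ones in the next two hunks) all follow one pattern: the argparse default is read from the dataclass field instead of being repeated as a literal, so each default is defined in exactly one place. A standalone sketch of the pattern, with made-up names:

    import argparse
    from dataclasses import dataclass

    @dataclass
    class Args:
        kv_cache_dtype: str = 'auto'   # single source of truth for the default

    parser = argparse.ArgumentParser()
    # Reading the class attribute keeps the CLI default in sync with the dataclass.
    parser.add_argument('--kv-cache-dtype', type=str, default=Args.kv_cache_dtype)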
@@ -154,6 +155,7 @@ class EngineArgs:
         parser.add_argument(
             '--max-parallel-loading-workers',
             type=int,
+            default=EngineArgs.max_parallel_loading_workers,
             help='load model sequentially in multiple batches, '
             'to avoid RAM OOM when using tensor '
             'parallel and large models')
@@ -200,7 +202,7 @@ class EngineArgs:
             '-q',
             type=str,
             choices=['awq', 'gptq', 'squeezellm', None],
-            default=None,
+            default=EngineArgs.quantization,
             help='Method used to quantize the weights. If '
             'None, we first check the `quantization_config` '
             'attribute in the model config file. If that is '
@@ -255,6 +257,13 @@ class EngineArgs:
             help=('Maximum number of LoRAs to store in CPU memory. '
                   'Must be >= than max_num_seqs. '
                   'Defaults to max_num_seqs.'))
+        parser.add_argument(
+            "--device",
+            type=str,
+            default=EngineArgs.device,
+            choices=["cuda"],
+            help=('Device type for vLLM execution. '
+                  'Currently, only CUDA-compatible devices are supported.'))
         return parser
 
     @classmethod
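With the new --device option registered, a parser built through EngineArgs.add_cli_args accepts the device on the command line. A hedged sketch of that flow; from_cli_args is assumed to exist as the usual counterpart of add_cli_args and is not part of this hunk:

    import argparse
    from vllm.engine.arg_utils import EngineArgs

    parser = argparse.ArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args(["--model", "facebook/opt-125m", "--device", "cuda"])
    engine_args = EngineArgs.from_cli_args(args)   # engine_args.device == "cuda"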
@@ -268,7 +277,8 @@ class EngineArgs:
     def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               Optional[LoRAConfig]]:
+               DeviceConfig, Optional[LoRAConfig]]:
+        device_config = DeviceConfig(self.device)
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
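DeviceConfig itself is imported from vllm.config (see the first hunk) and its definition is not shown in this excerpt. A minimal sketch of what a wrapper like this plausibly looks like, assuming it only normalizes the string into a torch.device:

    import torch

    class DeviceConfig:
        # Sketch only; the real class lives in vllm.config and may differ.
        def __init__(self, device: str = "cuda") -> None:
            # Keeping a torch.device lets downstream code pass it directly
            # to tensor allocations and model.to(...).
            self.device = torch.device(device)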
@@ -296,7 +306,8 @@ class EngineArgs:
             lora_dtype=self.lora_dtype,
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
-        return model_config, cache_config, parallel_config, scheduler_config, lora_config
+        return (model_config, cache_config, parallel_config, scheduler_config,
+                device_config, lora_config)
 
 
 @dataclass
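Callers of create_engine_configs now unpack one extra element. A short sketch of the updated call-site shape; the variable names are illustrative:

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(model="facebook/opt-125m", device="cuda")
    (model_config, cache_config, parallel_config, scheduler_config,
     device_config, lora_config) = engine_args.create_engine_configs()
    # device_config carries the user-selected device to the rest of the engine.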