Remove hardcoded device="cuda" to support more devices (#2503)

Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2024-02-02 07:46:39 +08:00
committed by GitHub
parent c410f5d020
commit 96b6f475dd
32 changed files with 343 additions and 292 deletions

View File

@@ -84,7 +84,7 @@ def create_worker(cls: type,
)
(model_config, cache_config, parallel_config, scheduler_config,
_) = engine_args.create_engine_configs()
device_config, _) = engine_args.create_engine_configs()
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
@@ -93,6 +93,7 @@ def create_worker(cls: type,
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,