Remove hardcoded device="cuda" to support more devices (#2503)
Co-authored-by: Jiang Li <jiang1.li@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -84,7 +84,7 @@ def create_worker(cls: type,
|
||||
)
|
||||
|
||||
(model_config, cache_config, parallel_config, scheduler_config,
|
||||
_) = engine_args.create_engine_configs()
|
||||
device_config, _) = engine_args.create_engine_configs()
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
@@ -93,6 +93,7 @@ def create_worker(cls: type,
|
||||
model_config=model_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
|
||||
Reference in New Issue
Block a user