[Model] enable data parallel for Llama4 vision encoder (#18368)
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com> Co-authored-by: yZhen <yZhen@fb.com> Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
This commit is contained in:
@@ -423,6 +423,9 @@ class EngineArgs:
|
||||
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
||||
pt_load_map_location: str = LoadConfig.pt_load_map_location
|
||||
|
||||
enable_multimodal_encoder_data_parallel: bool = \
|
||||
ParallelConfig.enable_multimodal_encoder_data_parallel
|
||||
|
||||
def __post_init__(self):
|
||||
# support `EngineArgs(compilation_config={...})`
|
||||
# without having to manually construct a
|
||||
@@ -637,6 +640,9 @@ class EngineArgs:
|
||||
**parallel_kwargs["worker_cls"])
|
||||
parallel_group.add_argument("--worker-extension-cls",
|
||||
**parallel_kwargs["worker_extension_cls"])
|
||||
parallel_group.add_argument(
|
||||
"--enable-multimodal-encoder-data-parallel",
|
||||
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
|
||||
|
||||
# KV cache arguments
|
||||
cache_kwargs = get_kwargs(CacheConfig)
|
||||
@@ -1078,6 +1084,8 @@ class EngineArgs:
|
||||
distributed_executor_backend=self.distributed_executor_backend,
|
||||
worker_cls=self.worker_cls,
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
enable_multimodal_encoder_data_parallel=self.
|
||||
enable_multimodal_encoder_data_parallel,
|
||||
)
|
||||
|
||||
speculative_config = self.create_speculative_config(
|
||||
|
||||
Reference in New Issue
Block a user