[Model] enable data parallel for Llama4 vision encoder (#18368)

Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
This commit is contained in:
jennyyyyzhen
2025-06-02 04:22:54 -07:00
committed by GitHub
parent 5b168b6d7a
commit ebb1ec9318
4 changed files with 214 additions and 68 deletions

View File

@@ -423,6 +423,9 @@ class EngineArgs:
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
@@ -637,6 +640,9 @@ class EngineArgs:
**parallel_kwargs["worker_cls"])
parallel_group.add_argument("--worker-extension-cls",
**parallel_kwargs["worker_extension_cls"])
parallel_group.add_argument(
"--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
@@ -1078,6 +1084,8 @@ class EngineArgs:
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel,
)
speculative_config = self.create_speculative_config(