[Model] enable data parallel for Llama4 vision encoder (#18368)

Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
This commit is contained in:
jennyyyyzhen committed 2025-06-02 04:22:54 -07:00 via GitHub
parent 5b168b6d7a
commit ebb1ec9318
4 changed files with 214 additions and 68 deletions


@@ -1790,6 +1790,10 @@ class ParallelConfig:
rank: int = 0
"""Global rank in distributed setup."""
enable_multimodal_encoder_data_parallel: bool = False
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world