[VLM][Model] TP support for ViTs (#7186)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Jungho Christopher Cho
2024-08-31 00:19:27 +09:00
committed by GitHub
parent afd39a4511
commit f97be32d1d
9 changed files with 336 additions and 285 deletions

View File

@@ -714,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
use_default_weight_loading = False
if "vision" in name:
if self.vision_model is not None:
- # We only do sharding for language model and
- # not vision model for now.
+ # BlipVisionModel does not need sharding
use_default_weight_loading = True
else:
for (param_name, weight_name,