Implicit language-model-only mode via limit-mm-per-prompt (#22299)
Signed-off-by: Roger Wang <hey@rogerw.me>
Signed-off-by: Andy Xie <andy.xning@gmail.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: XIn Li <xinli@nvidia.com>
Signed-off-by: Junhao Li <junhao@ubicloud.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: iAmir97 <Amir.balwel@embeddedllm.com>
Signed-off-by: iAmir97 <71513472+iAmir97@users.noreply.github.com>
Signed-off-by: Linkun <github@lkchen.net>
Co-authored-by: Ning Xie <andy.xning@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Co-authored-by: Zhiyu <zhiyuc@nvidia.com>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: XIn Li <xinli@nvidia.com>
Co-authored-by: Junhao Li <streaver91@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: Yuxuan Zhang <2448370773@qq.com>
Co-authored-by: ZiTian Zhao <zitian.zhao@tencentmusic.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Po-Han Huang (NVIDIA) <53919306+nvpohanh@users.noreply.github.com>
Co-authored-by: iAmir97 <71513472+iAmir97@users.noreply.github.com>
Co-authored-by: iAmir97 <Amir.balwel@embeddedllm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Hong Hanh <hanh.usth@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: lkchen <github@lkchen.net>
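With this change, setting the per-prompt image limit to 0 implicitly switches a multimodal checkpoint into language-model-only mode: the vision tower and projector are never constructed, and their checkpoint weights are skipped at load time. A minimal usage sketch (not part of the diff; the model name is only an example):

    from vllm import LLM

    # {"image": 0} means no images are allowed per prompt; with this change it
    # also skips building Llama4VisionModel and Llama4MultiModalProjector.
    llm = LLM(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # example checkpoint
        limit_mm_per_prompt={"image": 0},
    )
    outputs = llm.generate("What is the capital of France?")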
@@ -737,16 +737,20 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.vision_model = Llama4VisionModel(
-            config.vision_config,
-            None,
-            prefix=maybe_prefix(prefix, "vision_model"),
-            use_data_parallel=self.use_data_parallel,
-        )
-        self.multi_modal_projector = Llama4MultiModalProjector(
-            self.config,
-            None,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_model = Llama4VisionModel(
+                config.vision_config,
+                None,
+                prefix=maybe_prefix(prefix, "vision_model"),
+                use_data_parallel=self.use_data_parallel,
+            )
+            self.multi_modal_projector = Llama4MultiModalProjector(
+                self.config,
+                None,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None
         self.language_model = initialize_model(
             vllm_config=vllm_config.with_hf_config(config.text_config,
                                                    ["LlamaForCausalLM"]),
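The gate relies on Python truthiness: get_limit_per_prompt returns the per-prompt cap for a modality, so a cap of 0 is falsy and the vision modules are replaced with None. A standalone sketch of the same pattern (the fallback cap of 1 is an assumption for illustration, not vLLM's actual default):

    # Hedged illustration of the gating above, not vLLM code.
    limit_per_prompt = {"image": 0}

    def get_limit_per_prompt(modality: str) -> int:
        # Assumption: modalities without an explicit cap fall back to a
        # positive value, so only an explicit 0 disables them.
        return limit_per_prompt.get(modality, 1)

    if get_limit_per_prompt("image"):
        vision_model = object()  # stand-in for Llama4VisionModel(...)
    else:
        vision_model = None      # language-model-only mode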
@@ -783,6 +787,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def _process_image_input(
             self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
+        assert self.vision_model and self.multi_modal_projector
+
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
 
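Since vision_model and multi_modal_projector can now be None, the added assert fails fast if image features ever reach this path in text-only mode, and it narrows the Optional attribute types for static checkers. A tiny sketch of the same pattern, with hypothetical names:

    from typing import Optional

    class Tower:
        def encode(self, x: object) -> object:
            return x

    class Model:
        vision_model: Optional[Tower] = None

        def process(self, image: object) -> object:
            # Fails fast (and narrows Optional[Tower] to Tower) when the
            # tower was skipped at construction time.
            assert self.vision_model is not None
            return self.vision_model.encode(image)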
@@ -1048,6 +1054,10 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         language_model_weights, other_weights = (
             self._separate_and_rename_weights(weights))
 
+        # Skip loading vision model and projector if they're not initialized.
+        if self.vision_model is None and self.multi_modal_projector is None:
+            other_weights = []
+
         # Handle expert scale parameters
         regular_weights, expert_scale_weights, updated_params_from_experts = (
             self._handle_expert_scale_broadcasting(language_model_weights,
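When the vision modules were never built, their checkpoint tensors have no destination parameters, so other_weights is emptied and those tensors are simply not loaded. A hypothetical filter expressing the same idea (filter_weights and the prefixes are illustrative, not the vLLM loader API):

    from typing import Iterable, Iterator, Tuple

    def filter_weights(
        weights: Iterable[Tuple[str, object]],
        has_vision: bool,
    ) -> Iterator[Tuple[str, object]]:
        # Drop tensors addressed to submodules that were never constructed.
        skip = ("vision_model.", "multi_modal_projector.")
        for name, tensor in weights:
            if not has_vision and name.startswith(skip):
                continue
            yield name, tensor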