[4/N] Initialize MM components in context managers (M-P) (#32663)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -621,14 +621,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if inputs_embeds is None:
|
||||
multimodal_embeddings = self.embed_multimodal(**kwargs)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
)
|
||||
input_ids = None
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
input_ids,
|
||||
|
||||
@@ -791,14 +791,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.embed_multimodal(**kwargs)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
input_ids=input_ids,
|
||||
|
||||
@@ -71,8 +71,6 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
|
||||
|
||||
|
||||
class LMMissingLayer(nn.Module):
|
||||
packed_modules_mapping: dict[str, list[str]] = {}
|
||||
|
||||
def make_empty_intermediate_tensors(self, *args, **kwargs):
|
||||
raise RuntimeError("This module should not be called in MM encoder-only mode")
|
||||
|
||||
@@ -81,8 +79,6 @@ class LMMissingLayer(nn.Module):
|
||||
|
||||
|
||||
class TowerMissingLayer(nn.Module):
|
||||
packed_modules_mapping: dict[str, list[str]] = {}
|
||||
|
||||
def __init__(self, modalities: set[str] | str) -> None:
|
||||
if isinstance(modalities, str):
|
||||
modalities = {modalities}
|
||||
@@ -92,7 +88,10 @@ class TowerMissingLayer(nn.Module):
|
||||
self.modalities = modalities
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError(f"The following modalities are disabled: {self.modalities}")
|
||||
raise RuntimeError(
|
||||
f"This module should not be called when the following "
|
||||
f"modalities are disabled: {self.modalities}"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -789,7 +789,6 @@ class InternS1ForConditionalGeneration(
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
|
||||
forward_kwargs = {
|
||||
|
||||
@@ -1379,7 +1379,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
|
||||
forward_kwargs = {
|
||||
|
||||
@@ -707,30 +707,30 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
|
||||
# Initialize audio components
|
||||
self.audio_encoder = DashengAudioTransformer(
|
||||
config.audio_encoder_config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "audio_encoder"),
|
||||
)
|
||||
self.audio_projector = AudioProjectorSubsample(
|
||||
in_dim=config.audio_encoder_config.embed_dim,
|
||||
out_dim=config.text_config.hidden_size,
|
||||
downsample_rate=config.subsample_factor,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "audio_projector"),
|
||||
)
|
||||
|
||||
# Initialize language model (decoder)
|
||||
self.decoder = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "decoder"),
|
||||
architectures=["Qwen2ForCausalLM"],
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
with self._mark_tower_model(vllm_config, "audio"):
|
||||
self.audio_encoder = DashengAudioTransformer(
|
||||
config.audio_encoder_config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "audio_encoder"),
|
||||
)
|
||||
self.audio_projector = AudioProjectorSubsample(
|
||||
in_dim=config.audio_encoder_config.embed_dim,
|
||||
out_dim=config.text_config.hidden_size,
|
||||
downsample_rate=config.subsample_factor,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "audio_projector"),
|
||||
)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.decoder = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "decoder"),
|
||||
architectures=["Qwen2ForCausalLM"],
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.decoder.make_empty_intermediate_tensors
|
||||
)
|
||||
@@ -787,9 +787,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return torch.split(masked_audio_features, audio_output_lengths.tolist())
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.decoder
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
|
||||
|
||||
@@ -553,9 +553,11 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.apm = self.init_audio_module(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, "audio"):
|
||||
self.apm = self.init_audio_module(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
|
||||
)
|
||||
|
||||
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
# Do not use parameters temporarily
|
||||
|
||||
@@ -1028,25 +1028,27 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.version = get_version_by_config(self.config)
|
||||
self.llm = self.init_llm(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm")
|
||||
)
|
||||
self.vpm = self.init_vision_module(
|
||||
config, quant_config, prefix=maybe_prefix(prefix, "vpm")
|
||||
)
|
||||
self.vision_dim = (
|
||||
self.vpm.embed_dim
|
||||
if self.version == (2, 0)
|
||||
else self.vpm.embeddings.embed_dim
|
||||
)
|
||||
self.embed_dim = self.config.hidden_size
|
||||
|
||||
self.resampler = self.init_resampler(
|
||||
self.embed_dim,
|
||||
self.vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "resampler"),
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.llm = self.init_llm(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm")
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.vpm = vpm = self.init_vision_module(
|
||||
config, quant_config, prefix=maybe_prefix(prefix, "vpm")
|
||||
)
|
||||
self.vision_dim = (
|
||||
vpm.embed_dim if self.version == (2, 0) else vpm.embeddings.embed_dim
|
||||
)
|
||||
self.embed_dim = self.config.hidden_size
|
||||
|
||||
self.resampler = self.init_resampler(
|
||||
self.embed_dim,
|
||||
self.vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "resampler"),
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
|
||||
|
||||
@@ -1134,9 +1136,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.llm
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
|
||||
@@ -201,28 +201,33 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_tower = init_vision_tower_for_llava(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
require_post_norm=False,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
projector_hidden_act=config.projector_hidden_act,
|
||||
multimodal_projector_bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.vision_tower = init_vision_tower_for_llava(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
require_post_norm=False,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
projector_hidden_act=config.projector_hidden_act,
|
||||
multimodal_projector_bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
self.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size)
|
||||
)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.vision_feature_layer = config.vision_feature_layer
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.pad_token_id = -1
|
||||
@@ -233,9 +238,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def _image_pixels_to_features(
|
||||
self,
|
||||
vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
|
||||
@@ -302,8 +304,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
self,
|
||||
inputs: MiniMaxVL01ImagePixelInputs,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, ...]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = inputs["pixel_values"]
|
||||
return self._image_pixels_to_features(self.vision_tower, pixel_values)
|
||||
|
||||
@@ -314,7 +314,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_tower is not None
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
|
||||
if isinstance(image_features, torch.Tensor):
|
||||
@@ -369,14 +368,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.embed_multimodal(**kwargs)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||
|
||||
@@ -1441,15 +1441,20 @@ class MolmoForCausalLM(
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
vision_config = VisionBackboneConfig()
|
||||
self.vision_backbone = MolmoVisionBackbone(
|
||||
config,
|
||||
vision_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_backbone"),
|
||||
)
|
||||
self.model = MolmoModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.vision_backbone = MolmoVisionBackbone(
|
||||
config,
|
||||
vision_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_backbone"),
|
||||
)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.model = MolmoModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
self.img_patch_id = None
|
||||
|
||||
if self.config.weight_tying:
|
||||
@@ -1525,9 +1530,6 @@ class MolmoForCausalLM(
|
||||
results.append(feats[is_valid][order])
|
||||
return results
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@@ -2514,16 +2514,19 @@ class Molmo2ForConditionalGeneration(
|
||||
kwargs[field.name] = getattr(config.adapter_config, field.name)
|
||||
adapter_config = AdapterConfig(**kwargs)
|
||||
|
||||
self.vision_backbone = Molmo2VisionBackbone(
|
||||
vit_config,
|
||||
adapter_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_backbone"),
|
||||
)
|
||||
self.model = Molmo2TextModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.vision_backbone = Molmo2VisionBackbone(
|
||||
vit_config,
|
||||
adapter_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_backbone"),
|
||||
)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.model = Molmo2TextModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
|
||||
self.img_patch_id = config.image_patch_id
|
||||
|
||||
@@ -2687,9 +2690,6 @@ class Molmo2ForConditionalGeneration(
|
||||
out.append(out_features)
|
||||
return tuple(out)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
|
||||
@@ -1511,34 +1511,38 @@ class NemotronH_Nano_VL_V2(
|
||||
self.ps_version = config.ps_version
|
||||
self.image_tag_type = config.image_tag_type
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
||||
self.language_model.config.dtype
|
||||
)
|
||||
|
||||
# Construct the vision projection.
|
||||
vit_hidden_size = config.vit_hidden_size
|
||||
vision_projection_hidden_size = config.projector_hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.mlp1 = nn.Sequential(
|
||||
RMSNorm(
|
||||
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
|
||||
eps=1e-5,
|
||||
),
|
||||
nn.Linear(
|
||||
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
|
||||
vision_projection_hidden_size,
|
||||
bias=False,
|
||||
),
|
||||
ReLUSquaredActivation(),
|
||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
||||
)
|
||||
self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
||||
self.language_model.config.dtype
|
||||
)
|
||||
|
||||
# Construct the vision projection.
|
||||
vit_hidden_size = config.vit_hidden_size
|
||||
vision_projection_hidden_size = config.projector_hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
mlp1 = nn.Sequential(
|
||||
RMSNorm(
|
||||
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
|
||||
eps=1e-5,
|
||||
),
|
||||
nn.Linear(
|
||||
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
|
||||
vision_projection_hidden_size,
|
||||
bias=False,
|
||||
),
|
||||
ReLUSquaredActivation(),
|
||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
||||
)
|
||||
self.mlp1 = mlp1.to(language_model.config.dtype)
|
||||
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
@@ -1909,9 +1913,6 @@ class NemotronH_Nano_VL_V2(
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -1921,7 +1922,6 @@ class NemotronH_Nano_VL_V2(
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
|
||||
@@ -820,16 +820,18 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.encoder = RadioWithNeck(
|
||||
config=config, quant_config=quant_config, prefix=f"{prefix}.encoder"
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.encoder = RadioWithNeck(
|
||||
config=config, quant_config=quant_config, prefix=f"{prefix}.encoder"
|
||||
)
|
||||
|
||||
self.decoder = MBartDecoderNoPos(
|
||||
config.decoder,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.decoder",
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.decoder = MBartDecoderNoPos(
|
||||
config.decoder,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.decoder",
|
||||
)
|
||||
|
||||
self.vocab_size = config.decoder.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
@@ -883,9 +885,6 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
pixel_values = pixel_values.to(dtype)
|
||||
return self.encoder(pixel_values)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.decoder
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@@ -385,20 +385,20 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.ps_version = config.ps_version
|
||||
|
||||
self.llm_arch_name = config.text_config.architectures[0]
|
||||
self.vision_model = self._init_vision_model(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.vision_model = self._init_vision_model(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
)
|
||||
self.mlp1 = self._init_mlp1(config)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.mlp1 = self._init_mlp1(config)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.img_context_token_id = None
|
||||
|
||||
@@ -520,8 +520,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_model is not None
|
||||
|
||||
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
|
||||
|
||||
num_patches = image_input["num_patches"]
|
||||
@@ -556,9 +554,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
|
||||
self.visual_token_mask = None
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
@@ -609,7 +604,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
|
||||
forward_kwargs = {
|
||||
|
||||
@@ -417,7 +417,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
return IMAGE_TOKEN
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
@@ -427,20 +427,22 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config: PretrainedConfig = config
|
||||
self.llm = init_vllm_registered_model(
|
||||
vllm_config=vllm_config.with_hf_config(config.get_text_config()),
|
||||
prefix=maybe_prefix(prefix, "llm"),
|
||||
)
|
||||
|
||||
self.visual_tokenizer = VisualTokenizer(
|
||||
config=config.visual_tokenizer_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.visual_tokenizer",
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.llm = init_vllm_registered_model(
|
||||
vllm_config=vllm_config.with_hf_config(config.get_text_config()),
|
||||
prefix=maybe_prefix(prefix, "llm"),
|
||||
)
|
||||
|
||||
self.vte = VisualEmbedding(
|
||||
self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.visual_tokenizer = VisualTokenizer(
|
||||
config=config.visual_tokenizer_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.visual_tokenizer",
|
||||
)
|
||||
self.vte = VisualEmbedding(
|
||||
self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size
|
||||
)
|
||||
|
||||
text_model_type = self.config.get_text_config().model_type
|
||||
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
|
||||
@@ -546,12 +548,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.llm.compute_logits(hidden_states)
|
||||
return logits
|
||||
return self.llm.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.llm
|
||||
|
||||
@@ -451,6 +451,15 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
|
||||
dummy_inputs=Ovis2_5DummyInputsBuilder,
|
||||
)
|
||||
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return IMAGE_TOKEN
|
||||
if modality.startswith("video"):
|
||||
return VIDEO_TOKEN
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -458,20 +467,22 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
self.config: PretrainedConfig = config
|
||||
self.llm = init_vllm_registered_model(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
prefix=maybe_prefix(prefix, "llm"),
|
||||
)
|
||||
|
||||
self.visual_tokenizer = VisualTokenizer(
|
||||
config=config.vit_config,
|
||||
visual_vocab_size=config.visual_vocab_size,
|
||||
multimodal_config=multimodal_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.visual_tokenizer",
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.llm = init_vllm_registered_model(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
prefix=maybe_prefix(prefix, "llm"),
|
||||
)
|
||||
|
||||
self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual_tokenizer = VisualTokenizer(
|
||||
config=config.vit_config,
|
||||
visual_vocab_size=config.visual_vocab_size,
|
||||
multimodal_config=multimodal_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.visual_tokenizer",
|
||||
)
|
||||
self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
|
||||
|
||||
text_model_type = self.config.get_text_config().model_type
|
||||
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
|
||||
@@ -650,12 +661,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.llm.compute_logits(hidden_states)
|
||||
return logits
|
||||
return self.llm.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.llm
|
||||
|
||||
@@ -999,6 +999,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -1008,22 +1015,24 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.visual = SiglipVisionModel(
|
||||
config=config.vision_config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
self.mlp_AR = Projector(config, config.vision_config)
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.visual = SiglipVisionModel(
|
||||
config=config.vision_config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
self.mlp_AR = Projector(config, config.vision_config)
|
||||
|
||||
self.language_model = Ernie4_5ForCausalLM(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = language_model = Ernie4_5ForCausalLM(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
for layer in self.language_model.model.layers:
|
||||
if not isinstance(layer, PPMissingLayer):
|
||||
layer.self_attn.rotary_emb.is_neox_style = True
|
||||
for layer in language_model.model.layers:
|
||||
if not isinstance(layer, PPMissingLayer):
|
||||
layer.self_attn.rotary_emb.is_neox_style = True
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
@@ -1151,9 +1160,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_language_model(self) -> nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> PaddleOCRImagePixelInputs | None:
|
||||
@@ -1180,29 +1186,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.embed_multimodal(**kwargs)
|
||||
is_multimodal = kwargs.pop("is_multimodal", None)
|
||||
handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
return self.language_model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def encode_image(
|
||||
self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -295,30 +295,32 @@ class PaliGemmaForConditionalGeneration(
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.vision_tower = SiglipVisionModel(
|
||||
config.vision_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
projection_dim=config.vision_config.projection_dim,
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.vision_tower = SiglipVisionModel(
|
||||
config.vision_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
projection_dim=config.vision_config.projection_dim,
|
||||
)
|
||||
|
||||
if config.text_config.model_type == "gemma":
|
||||
config.text_config.architectures = ["GemmaForCausalLM"]
|
||||
else:
|
||||
config.text_config.architectures = ["Gemma2ForCausalLM"]
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.language_model.logits_processor.scale *= logit_scale
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
language_model.logits_processor.scale *= logit_scale
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
@@ -367,7 +369,6 @@ class PaliGemmaForConditionalGeneration(
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_tower is not None
|
||||
pixel_values = image_input["data"]
|
||||
image_features = self._image_pixels_to_features(
|
||||
self.vision_tower,
|
||||
@@ -376,9 +377,6 @@ class PaliGemmaForConditionalGeneration(
|
||||
|
||||
return self.multi_modal_projector(image_features)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@@ -586,31 +586,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
self.multimodal_config = multimodal_config
|
||||
self.image_token_id = _IMAGE_TOKEN_ID
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "model.embed_tokens"),
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "model.embed_tokens"),
|
||||
)
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
|
||||
)
|
||||
|
||||
# TODO: Optionally initializes this for supporting input embeddings.
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
# The prefix is empty intentionally because default prefix of
|
||||
# LlamaForCausalLM is "model"
|
||||
prefix="",
|
||||
# We don't directly initialize vLLM's LlamaForCausalLM so we
|
||||
# can automatically apply embedding wrapper if this model is
|
||||
# initialized as an embedding model
|
||||
architectures=["LlamaForCausalLM"],
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
# The prefix is empty intentionally because default prefix of
|
||||
# LlamaForCausalLM is "model"
|
||||
prefix="",
|
||||
# We don't directly initialize vLLM's LlamaForCausalLM so we
|
||||
# can automatically apply embedding wrapper if this model is
|
||||
# initialized as an embedding model
|
||||
architectures=["LlamaForCausalLM"],
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
@@ -652,17 +652,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_embed_tokens is not None
|
||||
|
||||
image_embeds = self.vision_embed_tokens(
|
||||
image_input["pixel_values"], image_input["image_sizes"]
|
||||
)
|
||||
|
||||
return image_embeds
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@@ -1027,12 +1027,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
# Tensor/Pipeline parallel not supported for now.
|
||||
assert get_pp_group().world_size == 1, "pipeline parallel is not supported"
|
||||
|
||||
self.vision_encoder = Phi4MMImageEncoder(
|
||||
config,
|
||||
quant_config,
|
||||
prefix="model.vision_embed_tokens",
|
||||
model_dir=config._name_or_path,
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.vision_encoder = Phi4MMImageEncoder(
|
||||
config,
|
||||
quant_config,
|
||||
prefix="model.vision_embed_tokens",
|
||||
model_dir=config._name_or_path,
|
||||
)
|
||||
|
||||
if isinstance(config.embd_layer["audio_embd_layer"], dict):
|
||||
embedding_config = {
|
||||
@@ -1044,10 +1045,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
"embedding_cls": self.config.embd_layer["embedding_cls"]
|
||||
}
|
||||
|
||||
self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
|
||||
self.model = LlamaModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, "audio"):
|
||||
self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.model = LlamaModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
@@ -1245,6 +1249,3 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
connector=["audio_projection_for_vision", "audio_projection"],
|
||||
tower_model=["vision_encoder", "embed_tokens_extend"],
|
||||
)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
@@ -453,15 +453,15 @@ class Qwen3VLMoeForConditionalGeneration(
|
||||
]
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = Qwen3MoeLLMForCausalLM(
|
||||
self.language_model = language_model = Qwen3MoeLLMForCausalLM(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
|
||||
)
|
||||
|
||||
# Whether to include the gate_up_proj mapping is determined by
|
||||
# the language model.
|
||||
self.packed_modules_mapping = (
|
||||
self.packed_modules_mapping | self.language_model.packed_modules_mapping
|
||||
)
|
||||
# Whether to include the gate_up_proj mapping is determined by
|
||||
# the language model.
|
||||
self.packed_modules_mapping = (
|
||||
self.packed_modules_mapping | language_model.packed_modules_mapping
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
|
||||
@@ -908,7 +908,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
|
||||
forward_kwargs = {
|
||||
|
||||
@@ -1104,14 +1104,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.embed_multimodal(**kwargs)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||
|
||||
@@ -597,14 +597,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.embed_multimodal(**kwargs)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
|
||||
Reference in New Issue
Block a user