[1/N] Initialize MM components in context managers (A-D) (#32632)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-20 14:12:42 +08:00
committed by GitHub
parent 4753f3bf69
commit b75e85dede
11 changed files with 240 additions and 268 deletions

View File

@@ -15,9 +15,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
@@ -539,30 +537,22 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config = vllm_config.quant_config
self.config = config
self.vision_tower = AriaVisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=f"{prefix}.vision_tower",
)
self.multi_modal_projector = AriaProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
)
self.vocab_size = config.text_config.vocab_size
self.language_model = AriaTextModel(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"),
)
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
)
self.lm_head = ParallelLMHead(
self.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.vocab_size, scale=logit_scale)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = AriaVisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=f"{prefix}.vision_tower",
)
self.multi_modal_projector = AriaProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
)
with self._mark_language_model(vllm_config):
self.language_model = AriaTextModel(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"),
)
def _parse_and_validate_image_input(
self, **kwargs: object
@@ -618,9 +608,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
@@ -654,9 +641,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)

View File

@@ -460,20 +460,21 @@ class AudioFlamingo3ForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.audio_tower = AudioFlamingo3Encoder(
config.audio_config,
)
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
self.quant_config = quant_config
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = AudioFlamingo3Encoder(
config.audio_config,
)
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -599,9 +600,6 @@ class AudioFlamingo3ForConditionalGeneration(
current_idx += count
return tuple(grouped_embeddings)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:

View File

@@ -343,21 +343,23 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.vocab_size = config.text_config.vocab_size
self.multi_modal_projector = AyaVisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "model"),
# Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
architectures=["Cohere2ForCausalLM"],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.multi_modal_projector = AyaVisionMultiModalProjector(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "model"),
# Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
architectures=["Cohere2ForCausalLM"],
)
@property
def dtype(self):
@@ -410,9 +412,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
},
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -44,6 +44,7 @@ from .interfaces import (
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
TowerMissingLayer,
)
from .siglip import SiglipVisionModel
from .utils import (
@@ -373,12 +374,13 @@ class BagelForConditionalGeneration(
# Initialize language model (Qwen2)
# Pass the llm_config from BagelConfig to initialize Qwen2 properly
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.llm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.llm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
# Initialize vision model (SigLIP) if visual understanding is enabled
if config.visual_und:
@@ -398,34 +400,35 @@ class BagelForConditionalGeneration(
)
vit_config.vision_use_head = False
self.vit_model = SiglipVisionModel(
config=vit_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vit_model"),
)
with self._mark_tower_model(vllm_config, "image"):
self.vit_model = SiglipVisionModel(
config=vit_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vit_model"),
)
# Initialize connector (MLP)
vit_hidden_size = config.vit_config.hidden_size
llm_hidden_size = config.llm_config.hidden_size
# Initialize connector (MLP)
vit_hidden_size = config.vit_config.hidden_size
llm_hidden_size = config.llm_config.hidden_size
self.connector = BagelVisionMLP(
in_features=vit_hidden_size,
hidden_features=llm_hidden_size,
out_features=llm_hidden_size,
act_layer=config.connector_act,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "connector"),
)
self.connector = BagelVisionMLP(
in_features=vit_hidden_size,
hidden_features=llm_hidden_size,
out_features=llm_hidden_size,
act_layer=config.connector_act,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "connector"),
)
# Position embedding for vision tokens
self.vit_pos_embed = PositionEmbedding(
max_num_patch_per_side=config.vit_max_num_patch_per_side,
hidden_size=llm_hidden_size,
)
# Position embedding for vision tokens
self.vit_pos_embed = PositionEmbedding(
max_num_patch_per_side=config.vit_max_num_patch_per_side,
hidden_size=llm_hidden_size,
)
else:
self.vit_model = None
self.connector = None
self.vit_pos_embed = None
self.vit_model = TowerMissingLayer("image")
self.connector = TowerMissingLayer("image")
self.vit_pos_embed = TowerMissingLayer("image")
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -502,9 +505,6 @@ class BagelForConditionalGeneration(
return self._process_image_input(image_input)
def get_language_model(self) -> nn.Module:
return self.language_model
def forward(
self,
input_ids: torch.Tensor,
@@ -540,14 +540,6 @@ class BagelForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights from checkpoint."""
skip_prefixes = []
# Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module
skip_prefixes.append("vit_pos_embed.pos_embed")
# If visual understanding is disabled, skip vision-related weights
if self.vit_model is None:
skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"])
# Skip generation-related weights since we only support text2text and image2text
# Filter out all image generation components:
# - 'moe_gen': MoE generation weights
@@ -587,5 +579,6 @@ class BagelForConditionalGeneration(
filtered_weights.append((name, tensor))
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
# Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module
loader = AutoWeightsLoader(self, skip_prefixes=["vit_pos_embed.pos_embed"])
return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -549,31 +549,31 @@ class Blip2ForConditionalGeneration(
+ 1 # include class token
)
# TODO: Optionally initializes this for supporting embeddings.
self.vision_model = BlipVisionModel(vision_config, quant_config)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = BlipVisionModel(vision_config, quant_config)
self.query_tokens = nn.Parameter(
torch.zeros(
1, config.num_query_tokens, config.qformer_config.hidden_size
)
)
self.qformer = Blip2QFormerModel(
config.qformer_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.qformer",
)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
config.text_config.hidden_size,
bias=True,
)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
)
self.qformer = Blip2QFormerModel(
config.qformer_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.qformer",
)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
config.text_config.hidden_size,
bias=True,
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -614,8 +614,6 @@ class Blip2ForConditionalGeneration(
return image_features
def _process_image_pixels(self, inputs: Blip2ImagePixelInputs) -> torch.Tensor:
assert self.vision_model is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_model, pixel_values)
@@ -624,7 +622,6 @@ class Blip2ForConditionalGeneration(
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
@@ -635,9 +632,6 @@ class Blip2ForConditionalGeneration(
return self.language_projection(query_output)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -853,28 +853,30 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_model = CLIPTextTransformer(
text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "text_model"),
)
self.vision_model = CLIPVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
with self._mark_language_model(vllm_config):
self.text_model = CLIPTextTransformer(
text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "text_model"),
)
self.text_projection = nn.Linear(
self.text_embed_dim,
self.projection_dim,
bias=False,
)
self.visual_projection = nn.Linear(
self.vision_embed_dim,
self.projection_dim,
bias=False,
)
self.text_projection = nn.Linear(
self.text_embed_dim,
self.projection_dim,
bias=False,
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = CLIPVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.visual_projection = nn.Linear(
self.vision_embed_dim,
self.projection_dim,
bias=False,
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
@@ -940,9 +942,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
return self.get_image_features(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def _embed_text_input_ids(
self,
input_ids: torch.Tensor,

View File

@@ -353,21 +353,23 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config)
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.vocab_size = config.text_config.vocab_size
self.multi_modal_projector = Cohere2VisionMultiModalProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=config.text_config.architectures,
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = Cohere2VisionMultiModalProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=config.text_config.architectures,
)
@property
def dtype(self):
@@ -437,9 +439,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
):
quant_config.modules_to_not_convert.append("vision_tower")
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -383,46 +383,48 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
self.sam_model = build_sam_vit_b()
clip_vision_config = CLIPVisionConfig(
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
num_hidden_layers=24,
image_size=224,
patch_size=14,
projection_dim=512,
layer_norm_eps=1e-5,
)
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.projector = MlpProjector(self.projector_config)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
# special token for image token sequence format
n_embed = self.projector_config.n_embed
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
if self.tile_tag == "2D":
# <|view_separator|>, <|\n|>
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
# This is a typo in original implementation
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
with self._mark_tower_model(vllm_config, "image"):
self.sam_model = build_sam_vit_b()
clip_vision_config = CLIPVisionConfig(
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
num_hidden_layers=24,
image_size=224,
patch_size=14,
projection_dim=512,
layer_norm_eps=1e-5,
)
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.projector = MlpProjector(self.projector_config)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
# special token for image token sequence format
n_embed = self.projector_config.n_embed
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
if self.tile_tag == "2D":
# <|view_separator|>, <|\n|>
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
# This is a typo in original implementation
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -552,9 +554,6 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
return vision_features
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -374,37 +374,39 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
self.vision = self._init_vision_module(
self.vision_config, quant_config, maybe_prefix(prefix, "vision")
)
self.projector = MlpProjector(self.projector_config)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
# special token for image token sequence format
embed_std = 1 / torch.sqrt(
torch.tensor(self.projector_config.n_embed, dtype=torch.float32)
)
if self.tile_tag == "2D":
# <|view_seperator|>, <|\n|>
self.image_newline = nn.Parameter(
torch.randn(self.projector_config.n_embed) * embed_std
)
# This is a typo in original implementation
self.view_seperator = nn.Parameter(
torch.randn(self.projector_config.n_embed) * embed_std
)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
with self._mark_tower_model(vllm_config, "image"):
self.vision = self._init_vision_module(
self.vision_config, quant_config, maybe_prefix(prefix, "vision")
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.text_config,
prefix=maybe_prefix(prefix, "language"),
)
self.projector = MlpProjector(self.projector_config)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
# special token for image token sequence format
embed_std = 1 / torch.sqrt(
torch.tensor(self.projector_config.n_embed, dtype=torch.float32)
)
if self.tile_tag == "2D":
# <|view_seperator|>, <|\n|>
self.image_newline = nn.Parameter(
torch.randn(self.projector_config.n_embed) * embed_std
)
# This is a typo in original implementation
self.view_seperator = nn.Parameter(
torch.randn(self.projector_config.n_embed) * embed_std
)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.text_config,
prefix=maybe_prefix(prefix, "language"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -603,9 +605,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -689,18 +689,21 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
else:
vision_config = self.config.vision_config
self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
with self._mark_language_model(vllm_config):
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -763,9 +766,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
return image_embeds.split(sizes)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
merge_size = self.vision_tower.spatial_merge_size
return num_image_tokens * (merge_size**2)

View File

@@ -83,7 +83,10 @@ class LMMissingLayer(nn.Module):
class TowerMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def __init__(self, modalities: set[str]) -> None:
def __init__(self, modalities: set[str] | str) -> None:
if isinstance(modalities, str):
modalities = {modalities}
super().__init__()
self.modalities = modalities