[Model] Add Ovis2.5 PP support (#23405)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-08-23 01:46:34 +08:00
committed by GitHub
parent 22cf679aad
commit 32d2b4064f
5 changed files with 185 additions and 105 deletions

View File

@@ -30,7 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
@@ -70,6 +70,7 @@ class VisualTokenizer(torch.nn.Module):
visual_vocab_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
@@ -77,6 +78,7 @@ class VisualTokenizer(torch.nn.Module):
config=config,
quant_config=quant_config,
prefix=f"{prefix}.vit",
use_data_parallel=use_data_parallel,
)
# reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS)
@@ -93,31 +95,33 @@ class VisualTokenizer(torch.nn.Module):
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
model_type = config.model_type
if model_type == "siglip2_navit":
return Siglip2NavitModel(config=config, )
return Siglip2NavitModel(config=config,
quant_config=quant_config,
prefix=prefix,
use_data_parallel=use_data_parallel)
raise ValueError(
f"Unsupported visual tokenizer model_type: {model_type}")
@property
def dtype(self):
def dtype(self) -> torch.dtype:
return next(self.head.parameters()).dtype
@property
def device(self):
def device(self) -> torch.device:
return next(self.head.parameters()).device
def tokenize(self, logits):
def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
tokens = torch.softmax(logits, dim=-1,
dtype=torch.float32).to(logits.dtype)
return tokens
def encode(self, pixel_values, grid_thws):
features = self.vit(pixel_values,
grid_thws,
output_hidden_states=True,
return_dict=True)
def encode(self, pixel_values: torch.Tensor,
grid_thws: torch.Tensor) -> torch.Tensor:
features = self.vit(pixel_values, grid_thws)
# refer to qwen2.5-vl patchmerger
seq_len, _ = features.shape
features = features.reshape(seq_len // (self.config.hidden_stride**2),
@@ -125,7 +129,8 @@ class VisualTokenizer(torch.nn.Module):
return features
def forward(self, pixel_values, grid_thws) -> torch.Tensor:
def forward(self, pixel_values: torch.Tensor,
grid_thws: torch.Tensor) -> torch.Tensor:
features = self.encode(pixel_values, grid_thws)
logits = self.head(features)
tokens = self.tokenize(logits)
@@ -395,7 +400,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
info=Ovis2_5ProcessingInfo,
dummy_inputs=Ovis2_5DummyInputsBuilder)
class Ovis2_5(nn.Module, SupportsMultiModal):
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -421,9 +426,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
# TODO(Isotr0py): PP support
# self.make_empty_intermediate_tensors = (
# self.language_model.make_empty_intermediate_tensors)
self.make_empty_intermediate_tensors = (
self.get_language_model().make_empty_intermediate_tensors)
def _parse_and_validate_visual_input(
self, is_video,
@@ -567,4 +571,4 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
return loader.load_weights(weights)
def get_language_model(self) -> torch.nn.Module:
return self.llm
return self.llm