[Misc] Automatically resolve HF processor init kwargs (#22005)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-01 13:44:10 +08:00
committed by GitHub
parent ad57f23f6a
commit 82de9b9d46
40 changed files with 334 additions and 727 deletions

View File

@@ -7,9 +7,8 @@
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, TypeVar, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
@@ -232,7 +231,7 @@ def image_to_pixel_values_skyworkr1v(
return pixel_values
class BaseSkyworkR1VProcessor(ABC):
class SkyworkR1VProcessor:
"""
This model doesn't define its own HF processor,
so we implement our own one here.
@@ -279,17 +278,18 @@ class BaseSkyworkR1VProcessor(ABC):
self.use_thumbnail: bool = config.use_thumbnail
@property
@abstractmethod
def image_token_id(self) -> int:
raise NotImplementedError
return self.tokenizer.get_vocab()[IMG_CONTEXT]
@abstractmethod
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
raise NotImplementedError
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def resolve_min_max_num(
self,
@@ -426,35 +426,15 @@ class BaseSkyworkR1VProcessor(ABC):
}
class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
@abstractmethod
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> BaseSkyworkR1VProcessor:
raise NotImplementedError
def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor:
return self.ctx.init_processor(
SkyworkR1VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@@ -464,7 +444,7 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Optional[BaseSkyworkR1VProcessor],
processor: Optional[SkyworkR1VProcessor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
@@ -500,10 +480,8 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
return largest_feature_pinpoint
_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
class SkyworkR1VDummyInputsBuilder(
BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@@ -527,7 +505,8 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
}
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
class SkyworkR1VMultiModalProcessor(
BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
def _call_hf_processor(
self,
@@ -617,31 +596,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
]
class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> SkyworkR1VProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
return self.ctx.init_processor(
SkyworkR1VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
)
@MULTIMODAL_REGISTRY.register_processor(
SkyworkR1VMultiModalProcessor,
info=SkyworkR1VProcessingInfo,