[Misc] Automatically resolve HF processor init kwargs (#22005)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -7,9 +7,8 @@
|
||||
# Copyright (c) 2025 Skywork
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Literal, Optional, TypedDict, TypeVar, Union
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -232,7 +231,7 @@ def image_to_pixel_values_skyworkr1v(
|
||||
return pixel_values
|
||||
|
||||
|
||||
class BaseSkyworkR1VProcessor(ABC):
|
||||
class SkyworkR1VProcessor:
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
@@ -279,17 +278,18 @@ class BaseSkyworkR1VProcessor(ABC):
|
||||
self.use_thumbnail: bool = config.use_thumbnail
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def image_token_id(self) -> int:
|
||||
raise NotImplementedError
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
|
||||
@abstractmethod
|
||||
def get_image_repl(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> PromptUpdateDetails[str]:
|
||||
raise NotImplementedError
|
||||
repl_features = IMG_CONTEXT * feature_size
|
||||
repl_full = IMG_START + repl_features + IMG_END
|
||||
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
def resolve_min_max_num(
|
||||
self,
|
||||
@@ -426,35 +426,15 @@ class BaseSkyworkR1VProcessor(ABC):
|
||||
}
|
||||
|
||||
|
||||
class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
|
||||
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
|
||||
def get_image_repl(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> PromptUpdateDetails[str]:
|
||||
repl_features = IMG_CONTEXT * feature_size
|
||||
repl_full = IMG_START + repl_features + IMG_END
|
||||
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
|
||||
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
@abstractmethod
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
**kwargs: object,
|
||||
) -> BaseSkyworkR1VProcessor:
|
||||
raise NotImplementedError
|
||||
def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor:
|
||||
return self.ctx.init_processor(
|
||||
SkyworkR1VProcessor,
|
||||
config=self.get_hf_config(),
|
||||
tokenizer=self.get_tokenizer(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
@@ -464,7 +444,7 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[BaseSkyworkR1VProcessor],
|
||||
processor: Optional[SkyworkR1VProcessor],
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
@@ -500,10 +480,8 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||
return largest_feature_pinpoint
|
||||
|
||||
|
||||
_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
|
||||
|
||||
|
||||
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
class SkyworkR1VDummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
@@ -527,7 +505,8 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
}
|
||||
|
||||
|
||||
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
class SkyworkR1VMultiModalProcessor(
|
||||
BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
@@ -617,31 +596,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
]
|
||||
|
||||
|
||||
class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
|
||||
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
**kwargs: object,
|
||||
) -> SkyworkR1VProcessor:
|
||||
if min_dynamic_patch is not None:
|
||||
kwargs["min_dynamic_patch"] = min_dynamic_patch
|
||||
if max_dynamic_patch is not None:
|
||||
kwargs["max_dynamic_patch"] = max_dynamic_patch
|
||||
if dynamic_image_size is not None:
|
||||
kwargs["dynamic_image_size"] = dynamic_image_size
|
||||
|
||||
return self.ctx.init_processor(
|
||||
SkyworkR1VProcessor,
|
||||
config=self.get_hf_config(),
|
||||
tokenizer=self.get_tokenizer(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
SkyworkR1VMultiModalProcessor,
|
||||
info=SkyworkR1VProcessingInfo,
|
||||
|
||||
Reference in New Issue
Block a user