[Bugfix] Standardize custom HF Processor init (#37289)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -196,8 +196,10 @@ class DeepseekOCRProcessingInfo(BaseProcessingInfo):
|
||||
crop_mode=CROP_MODE,
|
||||
strategy="v1",
|
||||
)
|
||||
|
||||
return self.ctx.get_hf_processor(
|
||||
DeepseekOCRProcessor, **{**kwargs, **v1_processor_config}
|
||||
DeepseekOCRProcessor,
|
||||
**{**v1_processor_config, **kwargs},
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
|
||||
@@ -76,8 +76,10 @@ class DeepseekOCR2ProcessingInfo(BaseProcessingInfo):
|
||||
crop_mode=CROP_MODE,
|
||||
strategy="v2",
|
||||
)
|
||||
|
||||
return self.ctx.get_hf_processor(
|
||||
DeepseekOCRProcessor, **{**kwargs, **v2_processor_config}
|
||||
DeepseekOCRProcessor,
|
||||
**{**v2_processor_config, **kwargs},
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
|
||||
@@ -47,7 +47,10 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
||||
from vllm.transformers_utils.processors.glm4v import GLM4VProcessor
|
||||
from vllm.transformers_utils.processors.glm4v import (
|
||||
GLM4VImageProcessorFast,
|
||||
GLM4VProcessor,
|
||||
)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
|
||||
@@ -387,15 +390,20 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(ChatGLMConfig)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
|
||||
def get_image_processor(self, **kwargs):
|
||||
config = self.get_hf_config()
|
||||
vision_config = config.vision_config
|
||||
image_size = vision_config["image_size"]
|
||||
|
||||
image_size = vision_config["image_size"]
|
||||
kwargs.setdefault("size", {"width": image_size, "height": image_size})
|
||||
|
||||
return GLM4VImageProcessorFast(**kwargs)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
|
||||
return self.ctx.init_processor(
|
||||
GLM4VProcessor,
|
||||
tokenizer=self.get_tokenizer(),
|
||||
**{**kwargs, "image_size": image_size},
|
||||
image_processor=self.get_image_processor(**kwargs),
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
|
||||
@@ -44,7 +44,10 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.processors.qwen_vl import QwenVLProcessor
|
||||
from vllm.transformers_utils.processors.qwen_vl import (
|
||||
QwenVLImageProcessorFast,
|
||||
QwenVLProcessor,
|
||||
)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
@@ -432,15 +435,20 @@ class QwenVLModel(QWenModel):
|
||||
|
||||
|
||||
class QwenVLProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
|
||||
def get_image_processor(self, **kwargs):
|
||||
config = self.get_hf_config()
|
||||
vision_config = config.visual
|
||||
image_size = vision_config["image_size"]
|
||||
|
||||
image_size = vision_config["image_size"]
|
||||
kwargs.setdefault("size", {"width": image_size, "height": image_size})
|
||||
|
||||
return QwenVLImageProcessorFast(**kwargs)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
|
||||
return self.ctx.init_processor(
|
||||
QwenVLProcessor,
|
||||
tokenizer=self.get_tokenizer(),
|
||||
**{**kwargs, "image_size": image_size},
|
||||
image_processor=self.get_image_processor(**kwargs),
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
|
||||
@@ -61,6 +61,10 @@ def get_qwen_vl_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
|
||||
|
||||
|
||||
class QwenVLTokenizer(TokenizerLike):
|
||||
image_start_tag: str
|
||||
image_end_tag: str
|
||||
image_pad_tag: str
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs) -> HfTokenizer:
|
||||
tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
|
||||
|
||||
@@ -29,13 +29,8 @@ class GLM4VProcessor(ProcessorMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor: GLM4VImageProcessorFast,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
image_size: int,
|
||||
image_processor: GLM4VImageProcessorFast | None = None,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
if image_processor is None:
|
||||
image_processor = GLM4VImageProcessorFast(
|
||||
size={"width": image_size, "height": image_size}
|
||||
)
|
||||
self.image_processor = image_processor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@@ -31,25 +31,12 @@ class QwenVLProcessor(ProcessorMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor: QwenVLImageProcessorFast,
|
||||
tokenizer: QwenVLTokenizer,
|
||||
image_size: int,
|
||||
image_processor: QwenVLImageProcessorFast | None = None,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
if image_processor is None:
|
||||
image_processor = QwenVLImageProcessorFast(
|
||||
size={"width": image_size, "height": image_size}
|
||||
)
|
||||
self.image_processor = image_processor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@property
|
||||
def image_start_tag(self) -> str:
|
||||
return self.tokenizer.image_start_tag # type: ignore[attr-defined]
|
||||
|
||||
@property
|
||||
def image_end_tag(self) -> str:
|
||||
return self.tokenizer.image_end_tag # type: ignore[attr-defined]
|
||||
|
||||
@property
|
||||
def image_pad_tag(self) -> str:
|
||||
return self.tokenizer.image_pad_tag # type: ignore[attr-defined]
|
||||
self.image_start_tag = tokenizer.image_start_tag
|
||||
self.image_end_tag = tokenizer.image_end_tag
|
||||
self.image_pad_tag = tokenizer.image_pad_tag
|
||||
|
||||
Reference in New Issue
Block a user