[Bugfix] Fix BAGEL online serving for text and image understanding (#31546)
Signed-off-by: Dylan1229 <yvanphys@gmail.com>
Signed-off-by: UED <zxr3611244710@gmail.com>
Signed-off-by: mr-ye-cao <yecaoyc2019@gmail.com>
Co-authored-by: UED <zxr3611244710@gmail.com>
Co-authored-by: mr-ye-cao <yecaoyc2019@gmail.com>
Co-authored-by: Mr-Ye-Cao <60802056+Mr-Ye-Cao@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -346,6 +346,13 @@ class BagelForConditionalGeneration(
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return "<|image_pad|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""BAGEL processor for image and text inputs."""
|
||||
|
||||
from transformers import AutoProcessor
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
@@ -44,12 +45,16 @@ class BagelProcessor(ProcessorMixin):
|
||||
text_inputs = self.tokenizer(text, **kwargs) if text is not None else None
|
||||
|
||||
if pixel_values is not None and text_inputs is not None:
|
||||
text_inputs["pixel_values"] = pixel_values["pixel_values"]
|
||||
return text_inputs
|
||||
# Combine text and image inputs into BatchFeature
|
||||
combined = dict(text_inputs)
|
||||
combined["pixel_values"] = pixel_values["pixel_values"]
|
||||
return BatchFeature(combined)
|
||||
elif pixel_values is not None:
|
||||
return pixel_values
|
||||
elif text_inputs is not None:
|
||||
return BatchFeature(dict(text_inputs))
|
||||
else:
|
||||
return text_inputs
|
||||
return BatchFeature({})
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user