[Bugfix] Clean up and fix multi-modal processors (#13012)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -63,18 +63,6 @@ from .utils import (flatten_bn, is_pp_missing_parameter,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
|
||||
# for the time being, these tags are not considered as special at encoding
|
||||
# time. This may change as VLLMs multimodal API changes in the future.
|
||||
IMG_START = "<img>"
|
||||
IMG_END = "</img>"
|
||||
IMG_PAD = "<imgpad>"
|
||||
# Image context is fixed at 256 for all images
|
||||
MAX_QWEN_IMG_TOKENS = 256
|
||||
# Image normalization params
|
||||
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
|
||||
class QwenImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
@@ -622,25 +610,6 @@ class QWenModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def build_normalization_transform(image_size: int) -> transforms.Compose:
|
||||
"""
|
||||
Build a normalization transform which can be applied to one or
|
||||
more input images from which we want to extract visual features.
|
||||
|
||||
Args:
|
||||
image_size: size of the image to be processed for visual embeddings.
|
||||
|
||||
Returns:
|
||||
Callable transform for normalizing and resizing one RGB image.
|
||||
"""
|
||||
return transforms.Compose([
|
||||
transforms.Resize((image_size, image_size),
|
||||
interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
|
||||
])
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_tokenizer_without_image_pad(
|
||||
tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
|
||||
@@ -716,16 +685,34 @@ class QWenVLProcessor:
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if hasattr(self.config, "visual"):
|
||||
self.image_transform = build_normalization_transform(
|
||||
config.visual["image_size"])
|
||||
if vision_config := getattr(self.config, "visual", None):
|
||||
image_size = vision_config["image_size"]
|
||||
|
||||
self.image_transform = transforms.Compose([
|
||||
transforms.Resize(
|
||||
(image_size, image_size),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711),
|
||||
),
|
||||
])
|
||||
else:
|
||||
self.image_transform = None
|
||||
|
||||
special_tokens: dict[str,
|
||||
int] = tokenizer.special_tokens # type: ignore
|
||||
self.img_start_id = special_tokens[IMG_START]
|
||||
self.img_end_id = special_tokens[IMG_END]
|
||||
@property
|
||||
def image_start_tag(self) -> str:
|
||||
return self.tokenizer.image_start_tag # type: ignore
|
||||
|
||||
@property
|
||||
def image_end_tag(self) -> str:
|
||||
return self.tokenizer.image_end_tag # type: ignore
|
||||
|
||||
@property
|
||||
def image_pad_tag(self) -> str:
|
||||
return self.tokenizer.image_pad_tag # type: ignore
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -787,7 +774,14 @@ class QWenVLProcessingInfo(BaseProcessingInfo):
|
||||
return {"image": self.get_num_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
return MAX_QWEN_IMG_TOKENS
|
||||
hf_config = self.get_hf_config()
|
||||
if not (vision_config := getattr(hf_config, "visual", None)):
|
||||
return 0
|
||||
|
||||
image_size = vision_config["image_size"]
|
||||
patch_size = vision_config["patch_size"]
|
||||
grid_length = image_size // patch_size // 2
|
||||
return grid_length * grid_length
|
||||
|
||||
|
||||
class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
|
||||
@@ -798,10 +792,12 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.info.get_hf_config()
|
||||
if not hasattr(hf_config, "visual"):
|
||||
if not (vision_config := getattr(hf_config, "visual", None)):
|
||||
return ProcessorInputs(prompt_text="", mm_data={})
|
||||
|
||||
vision_config = hf_config.visual
|
||||
processor = self.info.get_hf_processor()
|
||||
img_start = processor.image_start_tag
|
||||
img_end = processor.image_end_tag
|
||||
|
||||
target_width = target_height = vision_config["image_size"]
|
||||
num_images = mm_counts.get("image", 0)
|
||||
@@ -814,7 +810,7 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n"
|
||||
prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
|
||||
for i in range(1, num_images + 1)),
|
||||
mm_data=mm_data,
|
||||
)
|
||||
@@ -869,13 +865,18 @@ class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
if not hasattr(hf_config, "visual"):
|
||||
return []
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
special_tokens: dict[str,
|
||||
int] = tokenizer.special_tokens # type: ignore
|
||||
|
||||
img_start_id = special_tokens[IMG_START]
|
||||
img_end_id = special_tokens[IMG_END]
|
||||
img_pad_id = special_tokens[IMG_PAD]
|
||||
processor = self.info.get_hf_processor()
|
||||
img_start_id = special_tokens[processor.image_start_tag]
|
||||
img_end_id = special_tokens[processor.image_end_tag]
|
||||
img_pad_id = special_tokens[processor.image_pad_tag]
|
||||
|
||||
num_image_tokens = self.info.get_num_image_tokens()
|
||||
image_tokens = [img_pad_id] * num_image_tokens
|
||||
|
||||
Reference in New Issue
Block a user