[Bugfix] Clean up and fix multi-modal processors (#13012)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-02-10 18:45:21 +08:00
committed by GitHub
parent fde71262e0
commit 51f0b5f7f6
7 changed files with 124 additions and 154 deletions

View File

@@ -63,18 +63,6 @@ from .utils import (flatten_bn, is_pp_missing_parameter,
logger = init_logger(__name__)
# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
# for the time being, these tags are not considered as special at encoding
# time. This may change as VLLMs multimodal API changes in the future.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_PAD = "<imgpad>"
# Image context is fixed at 256 for all images
MAX_QWEN_IMG_TOKENS = 256
# Image normalization params
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
class QwenImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
@@ -622,25 +610,6 @@ class QWenModel(nn.Module):
return hidden_states
def build_normalization_transform(image_size: int) -> transforms.Compose:
"""
Build a normalization transform which can be applied to one or
more input images from which we want to extract visual features.
Args:
image_size: size of the image to be processed for visual embeddings.
Returns:
Callable transform for normalizing and resizing one RGB image.
"""
return transforms.Compose([
transforms.Resize((image_size, image_size),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
@@ -716,16 +685,34 @@ class QWenVLProcessor:
self.config = config
self.tokenizer = tokenizer
if hasattr(self.config, "visual"):
self.image_transform = build_normalization_transform(
config.visual["image_size"])
if vision_config := getattr(self.config, "visual", None):
image_size = vision_config["image_size"]
self.image_transform = transforms.Compose([
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
])
else:
self.image_transform = None
special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore
self.img_start_id = special_tokens[IMG_START]
self.img_end_id = special_tokens[IMG_END]
@property
def image_start_tag(self) -> str:
return self.tokenizer.image_start_tag # type: ignore
@property
def image_end_tag(self) -> str:
return self.tokenizer.image_end_tag # type: ignore
@property
def image_pad_tag(self) -> str:
return self.tokenizer.image_pad_tag # type: ignore
def __call__(
self,
@@ -787,7 +774,14 @@ class QWenVLProcessingInfo(BaseProcessingInfo):
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
return MAX_QWEN_IMG_TOKENS
hf_config = self.get_hf_config()
if not (vision_config := getattr(hf_config, "visual", None)):
return 0
image_size = vision_config["image_size"]
patch_size = vision_config["patch_size"]
grid_length = image_size // patch_size // 2
return grid_length * grid_length
class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
@@ -798,10 +792,12 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.info.get_hf_config()
if not hasattr(hf_config, "visual"):
if not (vision_config := getattr(hf_config, "visual", None)):
return ProcessorInputs(prompt_text="", mm_data={})
vision_config = hf_config.visual
processor = self.info.get_hf_processor()
img_start = processor.image_start_tag
img_end = processor.image_end_tag
target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
@@ -814,7 +810,7 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
}
return ProcessorInputs(
prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n"
prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
for i in range(1, num_images + 1)),
mm_data=mm_data,
)
@@ -869,13 +865,18 @@ class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
if not hasattr(hf_config, "visual"):
return []
tokenizer = self.info.get_tokenizer()
special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore
img_start_id = special_tokens[IMG_START]
img_end_id = special_tokens[IMG_END]
img_pad_id = special_tokens[IMG_PAD]
processor = self.info.get_hf_processor()
img_start_id = special_tokens[processor.image_start_tag]
img_end_id = special_tokens[processor.image_end_tag]
img_pad_id = special_tokens[processor.image_pad_tag]
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [img_pad_id] * num_image_tokens