[Model] Use merge_by_field_config for MM models (Llava family) (#26280)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-06 17:45:26 +08:00
committed by GitHub
parent 391612e78b
commit 19a00eb210
9 changed files with 155 additions and 229 deletions

View File

@@ -371,6 +371,115 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-8B-Preview"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1_5-8B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=32768,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "moonshotai/Kimi-VL-A3B-Instruct"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=4,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
@@ -505,115 +614,6 @@ def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestDa
) )
def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-8B-Preview"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1_5-8B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=32768,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "moonshotai/Kimi-VL-A3B-Instruct"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=4,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

View File

@@ -57,7 +57,6 @@ from .siglip import SiglipVisionModel
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
@@ -507,6 +506,8 @@ def init_vision_tower_for_llava(
dummy_inputs=LlavaDummyInputsBuilder, dummy_inputs=LlavaDummyInputsBuilder,
) )
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
@@ -592,37 +593,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if self.config.vision_config.model_type == "pixtral": if self.config.vision_config.model_type == "pixtral":
return PixtralHFImagePixelInputs( return PixtralHFImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=pixel_values,
) )
expected_h = expected_w = self.config.vision_config.image_size expected_h = expected_w = self.config.vision_config.image_size
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=flatten_bn(pixel_values, concat=True), pixel_values=pixel_values,
resolve_bindings={"h": expected_h, "w": expected_w}, resolve_bindings={"h": expected_h, "w": expected_w},
) )
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
if self.config.vision_config.model_type == "pixtral": if self.config.vision_config.model_type == "pixtral":
raise ValueError("Pixtral-HF does not support image_embeds.") raise ValueError("Pixtral-HF does not support image_embeds.")
return LlavaImageEmbeddingInputs( return LlavaImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds, concat=True), data=image_embeds,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")

View File

@@ -34,7 +34,6 @@ from .siglip import SiglipVisionModel
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
@@ -222,6 +221,8 @@ class LlavaNextMultiModalProcessor(
dummy_inputs=LlavaDummyInputsBuilder, dummy_inputs=LlavaDummyInputsBuilder,
) )
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52 # mapping for new names in checkpoint saved after transformers v4.52
@@ -302,21 +303,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
)
expected_h = expected_w = self.config.vision_config.image_size expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextImagePixelInputs( return LlavaNextImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=flatten_bn(pixel_values), pixel_values=pixel_values,
image_sizes=flatten_bn(image_sizes, concat=True), image_sizes=image_sizes,
resolve_bindings={ resolve_bindings={
"h": expected_h, "h": expected_h,
"w": expected_w, "w": expected_w,
@@ -324,14 +315,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
) )
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError(
f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
)
return LlavaNextImageEmbeddingInputs( return LlavaNextImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds), data=image_embeds,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")

View File

@@ -51,14 +51,13 @@ from .vision import get_vision_encoder_info
class LlavaNextVideoPixelInputs(TensorSchema): class LlavaNextVideoPixelInputs(TensorSchema):
""" """
Dimensions: Dimensions:
- bs: Batch size - bn: Batch size * number of videos
- nv: Number of videos - f: Number of frames
- nf: Number of frames - c: Number of channels (3)
- nc: Number of channels (3)
- h: Height of each frame - h: Height of each frame
- w: Width of each frame - w: Width of each frame
Note that `num_frames` may be different for each batch, in which case Note that `f` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor. the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch. Note that it only supports one video input for one batch.
@@ -66,9 +65,9 @@ class LlavaNextVideoPixelInputs(TensorSchema):
type: Literal["pixel_values_videos"] = "pixel_values_videos" type: Literal["pixel_values_videos"] = "pixel_values_videos"
data: Annotated[ pixel_values_videos: Annotated[
Union[torch.Tensor, list[torch.Tensor]], Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bs", "nv", "nf", 3, "h", "w"), TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}),
] ]
@@ -300,6 +299,8 @@ class LlavaNextMultiModalProjector(nn.Module):
dummy_inputs=LlavaNextVideoDummyInputsBuilder, dummy_inputs=LlavaNextVideoDummyInputsBuilder,
) )
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52 # mapping for new names in checkpoint saved after transformers v4.52
@@ -371,7 +372,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
expected_h = expected_w = self.config.vision_config.image_size expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextVideoPixelInputs( return LlavaNextVideoPixelInputs(
type="pixel_values_videos", type="pixel_values_videos",
data=pixel_values_videos, pixel_values_videos=pixel_values_videos,
resolve_bindings={ resolve_bindings={
"h": expected_h, "h": expected_h,
"w": expected_w, "w": expected_w,
@@ -396,19 +397,15 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs): def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
assert self.vision_tower is not None assert self.vision_tower is not None
video_pixels = inputs["data"] video_pixels = inputs["pixel_values_videos"]
if isinstance(video_pixels, torch.Tensor): if isinstance(video_pixels, torch.Tensor):
# TODO: support multiple videos per input bn, f, c, h, w = video_pixels.shape
b, num_videos, num_frames, c, h, w = video_pixels.shape stacked_pixels = video_pixels.view(bn * f, c, h, w)
assert num_videos == 1
stacked_pixels = video_pixels.view(b * num_videos * num_frames, c, h, w)
stacked_embeddings = self._video_pixels_to_features( stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, stacked_pixels self.vision_tower, stacked_pixels
) )
embeds = stacked_embeddings.view( embeds = stacked_embeddings.view(bn, f, *stacked_embeddings.shape[1:])
b, num_frames, *stacked_embeddings.shape[1:]
)
elif is_list_of(video_pixels, torch.Tensor): elif is_list_of(video_pixels, torch.Tensor):
frames_per_videos = [v.shape[0] for v in video_pixels] frames_per_videos = [v.shape[0] for v in video_pixels]

View File

@@ -44,7 +44,6 @@ from .siglip import SiglipVisionModel
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
@@ -62,7 +61,7 @@ class LlavaOnevisionVideoPixelInputs(TensorSchema):
- h: Height - h: Height
- w: Width - w: Width
Note that `num_videos` may be different for each batch, and 'num_frames' Note that `f` may be different for each batch, and 'num_frames'
may be different for each video, in which case the data is passed as a may be different for each video, in which case the data is passed as a
list instead of a batched tensor. list instead of a batched tensor.
""" """
@@ -480,6 +479,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
dummy_inputs=LlavaOnevisionDummyInputsBuilder, dummy_inputs=LlavaOnevisionDummyInputsBuilder,
) )
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52 # mapping for new names in checkpoint saved after transformers v4.52
@@ -539,20 +540,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
)
return LlavaOnevisionImagePixelInputs( return LlavaOnevisionImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=flatten_bn(pixel_values), pixel_values=pixel_values,
image_sizes=flatten_bn(image_sizes, concat=True), image_sizes=image_sizes,
resolve_bindings={ resolve_bindings={
"h": self.config.vision_config.image_size, "h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size, "w": self.config.vision_config.image_size,
@@ -560,14 +551,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
) )
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError(
f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
)
return LlavaOnevisionImageEmbeddingInputs( return LlavaOnevisionImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds), data=image_embeds,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
@@ -586,15 +572,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
if pixel_values_videos is None: if pixel_values_videos is None:
return None return None
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel_values_videos. "
f"Got type: {type(pixel_values_videos)}"
)
return LlavaOnevisionVideoPixelInputs( return LlavaOnevisionVideoPixelInputs(
type="pixel_values_videos", type="pixel_values_videos",
pixel_values_videos=flatten_bn(pixel_values_videos), pixel_values_videos=pixel_values_videos,
resolve_bindings={ resolve_bindings={
"h": self.config.vision_config.image_size, "h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size, "w": self.config.vision_config.image_size,

View File

@@ -32,7 +32,6 @@ from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
@@ -180,6 +179,8 @@ class MiniMaxVL01MultiModalProcessor(
dummy_inputs=MiniMaxVL01DummyInputsBuilder, dummy_inputs=MiniMaxVL01DummyInputsBuilder,
) )
class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
@@ -338,32 +339,16 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
return None return None
if pixel_values is not None and image_sizes is not None: if pixel_values is not None and image_sizes is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
)
return MiniMaxVL01ImagePixelInputs( return MiniMaxVL01ImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=flatten_bn(pixel_values), pixel_values=pixel_values,
image_sizes=flatten_bn(image_sizes, concat=True), image_sizes=image_sizes,
) )
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
return MiniMaxVL01ImageEmbeddingInputs( return MiniMaxVL01ImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds, concat=True), data=image_embeds,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")

View File

@@ -52,7 +52,6 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
@@ -424,6 +423,8 @@ def init_vision_tower_for_llava(
class Mistral3ForConditionalGeneration( class Mistral3ForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
): ):
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
@@ -510,15 +511,9 @@ class Mistral3ForConditionalGeneration(
if pixel_values is None and image_embeds is None: if pixel_values is None and image_embeds is None:
return None return None
assert pixel_values is not None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
return Mistral3ImagePixelInputs( return Mistral3ImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=pixel_values,
) )
def _process_image_input( def _process_image_input(

View File

@@ -64,7 +64,7 @@ from vllm.transformers_utils.tokenizer import (
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix from .utils import init_vllm_registered_model, maybe_prefix
from .vision import ( from .vision import (
VisionEncoderInfo, VisionEncoderInfo,
VisionFeatureSelectStrategy, VisionFeatureSelectStrategy,
@@ -365,6 +365,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
dummy_inputs=PixtralDummyInputsBuilder, dummy_inputs=PixtralDummyInputsBuilder,
) )
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"): if modality.startswith("image"):
@@ -424,7 +426,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return PixtralImagePixelInputs( return PixtralImagePixelInputs(
type="pixel_values", type="pixel_values",
images=flatten_bn(images), images=images,
) )
def _process_image_input( def _process_image_input(

View File

@@ -49,7 +49,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
@@ -404,6 +403,8 @@ def init_vision_tower_for_tarsier(
dummy_inputs=TarsierDummyInputsBuilder, dummy_inputs=TarsierDummyInputsBuilder,
) )
class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
@@ -467,25 +468,15 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
return TarsierImagePixelInputs( return TarsierImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=flatten_bn(pixel_values, concat=True), pixel_values=pixel_values,
) )
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
return TarsierImageEmbeddingInputs( return TarsierImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds, concat=True), data=image_embeds,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")