Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)
This commit is contained in:
@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
@@ -46,7 +46,8 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
||||
from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
|
||||
scatter_patch_features, select_patch_features)
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
@@ -67,6 +68,14 @@ class PixtralImagePixelInputs(TypedDict):
|
||||
The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class PixtralProcessorAdapter:
|
||||
"""
|
||||
@@ -135,8 +144,11 @@ class PixtralProcessorAdapter:
|
||||
"For more info, see: "
|
||||
"https://github.com/vllm-project/vllm/issues/8411.")
|
||||
|
||||
image_token_id = self.image_token_id
|
||||
|
||||
images_processed = list[torch.Tensor]()
|
||||
images_tokens = list[torch.Tensor]()
|
||||
images_embed_is_patch = list[torch.Tensor]()
|
||||
|
||||
for image in images:
|
||||
image_inputs = self.image_processor(ImageChunk(image=image))
|
||||
@@ -145,10 +157,12 @@ class PixtralProcessorAdapter:
|
||||
|
||||
images_processed.append(image_processed)
|
||||
images_tokens.append(image_tokens)
|
||||
images_embed_is_patch.append(image_tokens == image_token_id)
|
||||
|
||||
return {
|
||||
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
|
||||
"images": images_processed,
|
||||
"embed_is_patch": images_embed_is_patch,
|
||||
}
|
||||
|
||||
|
||||
@@ -199,7 +213,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
|
||||
ncols, nrows = processor.image_processor._image_to_num_tokens(
|
||||
Image.new("RGB", (image_width, image_height)))
|
||||
|
||||
return ncols * nrows
|
||||
return (ncols + 1) * nrows
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
image_processor = self.get_hf_processor().image_processor
|
||||
@@ -249,7 +263,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(images=MultiModalFieldConfig.batched("image"))
|
||||
return dict(
|
||||
images=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -273,7 +290,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
|
||||
tokens[-1] = image_end_id
|
||||
|
||||
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
|
||||
return tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@@ -364,9 +381,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of images. "
|
||||
f"Got type: {type(images)}")
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return PixtralImagePixelInputs(
|
||||
type="pixel_values",
|
||||
images=flatten_bn(images),
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
@@ -402,7 +427,12 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@@ -414,7 +444,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.vision_args.image_token_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
@@ -933,7 +963,9 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
)
|
||||
return ncols * nrows
|
||||
|
||||
# Consider the image_break_token
|
||||
return (ncols + 1) * nrows
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
image_size = self.get_image_size()
|
||||
|
||||
Reference in New Issue
Block a user