[VLM] Fully dynamic prompt replacement in merged input processor (#11199)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-15 01:52:18 +08:00
committed by GitHub
parent 9c3dadd1c9
commit 93abf23a64
12 changed files with 565 additions and 506 deletions

View File

@@ -32,13 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalDataDict,
MultiModalProcessingMetadata,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -305,64 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
# NOTE(review): the lines below are a unified-diff fragment with the +/- markers
# and indentation stripped; the OLD signature (taking `num_crops` and building
# `mm_processor_kwargs` for cached_get_image_processor) is interleaved with the
# NEW implementation (reading the image processor off ctx.get_hf_processor()).
# This span is not valid Python as rendered — do not edit without recovering
# the real diff.
# OLD version (removed by the commit):
def get_max_phi3v_image_tokens(ctx: InputContext,
*,
num_crops: Optional[int] = None):
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs["num_crops"] = num_crops
# NEW version (added by the commit): computes the max image-token count for
# an image at MAX_IMAGE_FEATURE_SIZE_WIDTH x MAX_IMAGE_FEATURE_SIZE_HEIGHT.
def get_max_phi3v_image_tokens(ctx: InputContext) -> int:
processor = ctx.get_hf_processor()
image_processor = processor.image_processor # type: ignore
# OLD body continues below (model_config / cached_get_image_processor lines
# belong to the removed version; the two `return`-style lines are the old
# `num_tokens = ...` assignment and the new direct `return ...`):
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
return image_processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return num_tokens
def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
                              mm_counts: Mapping[str, int]):
    """Build placeholder multimodal kwargs for Phi-3V memory profiling.

    Creates ``mm_counts["image"]`` dummy CLIP images at the maximum supported
    feature size and runs them through the model's HF image processor, so the
    result has the same structure as real preprocessed image inputs.
    """
    dummy_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        mm_counts["image"],
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    # The image processor hangs off the combined HF processor object.
    image_processor = ctx.get_hf_processor().image_processor  # type: ignore
    processed = image_processor.preprocess(dummy_data['image'],
                                           return_tensors="pt")

    return MultiModalKwargs(**processed)
def create_metadata_for_phi3v(
        ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
    """Describe how Phi-3V expands image placeholder tokens.

    Each ``_IMAGE_TOKEN_ID`` occurrence is repeated up to the maximum
    per-image token count computed by :func:`get_max_phi3v_image_tokens`.
    """
    image_repl = PromptReplacement(
        target=[_IMAGE_TOKEN_ID],
        repl_unit=[_IMAGE_TOKEN_ID],
        repl_count=get_max_phi3v_image_tokens(ctx),
    )
    return {
        "image": ModalityProcessingMetadata(prompt_repls=[image_repl]),
    }
class Phi3VProcessor(BaseMultiModalProcessor):
    """Multi-modal input processor for Phi-3V.

    Thin subclass that wires the Phi-3V-specific processing metadata into
    the generic :class:`BaseMultiModalProcessor`.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        metadata = create_metadata_for_phi3v(ctx)
        super().__init__(ctx=ctx, metadata=metadata)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _get_hf_processor(
self,
@@ -389,15 +339,61 @@ class Phi3VProcessor(BaseMultiModalProcessor):
processed_outputs['input_ids'] = token_ids
return processed_outputs
# NOTE(review): diff fragment with +/- markers stripped. The next line is the
# OLD removed `_get_dummy_mm_kwargs` signature; the line after it starts the
# NEW `_get_prompt_replacements` method. Not valid Python as rendered.
def _get_dummy_mm_kwargs(
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore
# Cap replacements at the configured per-prompt image limit (default 1).
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
# Per-item callback: the replacement length depends on each image's size,
# computed via the HF image processor.
def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
)
return [_IMAGE_TOKEN_ID] * num_tokens
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:max_images]
]
# NOTE(review): below, the NEW `_get_dummy_mm_inputs` is fused with two
# removed lines of the OLD method (the `-> MultiModalKwargs:` annotation and
# the `dummy_mm_kwargs_for_phi3v` delegation).
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)
) -> ProcessorInputs:
num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
# Dummy prompt: one placeholder token string per dummy image.
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
mm_processor_kwargs={},
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):