[Core] Rename input data types (#8688)
This commit is contained in:
@@ -9,7 +9,8 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
@@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_token_counts(
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
@@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
image_feature_size = get_blip2_image_feature_size(hf_config)
|
||||
@@ -460,15 +461,15 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
# The original model places image tokens at the front
|
||||
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
|
||||
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
|
||||
new_token_ids += llm_inputs["prompt_token_ids"]
|
||||
new_token_ids += inputs["prompt_token_ids"]
|
||||
|
||||
new_prompt = llm_inputs.get("prompt")
|
||||
new_prompt = inputs.get("prompt")
|
||||
if new_prompt is not None:
|
||||
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
|
||||
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
|
||||
Reference in New Issue
Block a user