[Core] Rename input data types (#8688)

This commit is contained in:
Cyrus Leung
2024-10-16 18:49:37 +08:00
committed by GitHub
parent 1de76a0e55
commit cee711fdbb
32 changed files with 438 additions and 340 deletions

View File

@@ -36,7 +36,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
@@ -256,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
return SequenceData.from_token_counts((0, seq_len))
return SequenceData.from_prompt_token_counts((0, seq_len))
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
@@ -279,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
return seq_data, mm_data
def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
model_config = ctx.model_config
version = get_version_by_config(model_config.hf_config)
tokenizer = cached_get_tokenizer(
@@ -297,8 +298,8 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return image_processor. \
get_slice_image_placeholder(image_size, num_image)
prompt = llm_inputs.get("prompt")
token_ids = llm_inputs.get("prompt_token_ids")
prompt = inputs.get("prompt")
token_ids = inputs.get("prompt_token_ids")
if prompt is None:
prompt = tokenizer.decode(token_ids)
@@ -332,12 +333,11 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
_build_image_input(ctx, image) for image in images
]
llm_inputs = LLMInputs(
return token_inputs(
prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
)
return llm_inputs
def input_mapper_for_minicpmv(ctx: InputContext, data: object):