[Core] Registry for processing model inputs (#5214)

Co-authored-by: ywang96 <ywang@roblox.com>
Author: Cyrus Leung
Committed: 2024-06-28 20:09:56 +08:00 (committed by GitHub)
Parent: 0d0e3a42ac
Commit: 5cbe8d155c
26 changed files with 784 additions and 398 deletions

vllm/model_executor/models/llava.py

@@ -2,10 +2,11 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
-from transformers import LlavaConfig
+from transformers import CLIPVisionConfig, LlavaConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.inputs import INPUT_REGISTRY, InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -16,10 +17,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import get_dummy_image_data
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
 from vllm.sequence import SamplerOutput
 
+from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
+                   dummy_seq_data_for_clip)
 from .interfaces import SupportsVision
 
 _KEYS_TO_MODIFY_MAPPING = {
@@ -83,9 +85,35 @@ class LlavaImageFeatureInputs(TypedDict):
 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
 
 
-@MULTIMODAL_REGISTRY.register_image_feature_input()
-@MULTIMODAL_REGISTRY.register_image_pixel_input()
-@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
+def dummy_data_for_llava(ctx: InputContext, seq_len: int):
+    multimodal_config = ctx.get_multimodal_config()
+    hf_config = ctx.get_hf_config(LlavaConfig)
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        seq_data = dummy_seq_data_for_clip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+        )
+
+        image_input_type = multimodal_config.image_input_type
+        ImageInputType = VisionLanguageConfig.ImageInputType
+        mm_data: MultiModalData
+        if image_input_type == ImageInputType.PIXEL_VALUES:
+            mm_data = dummy_pixel_data_for_clip(vision_config)
+        elif image_input_type == ImageInputType.IMAGE_FEATURES:
+            mm_data = dummy_feature_data_for_clip(vision_config)
+
+        return seq_data, mm_data
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
 class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
     def __init__(self,
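
Taken together, the hunks above move the dummy-data hook off MULTIMODAL_REGISTRY and onto the new INPUT_REGISTRY as a per-model factory, while the multimodal registry keeps the (renamed) input-mapper decorators. The following is a minimal sketch of how another CLIP-based vision-language model could opt into the same pattern; it is not part of this commit. MyVLMForCausalLM and MyVLMConfig are hypothetical placeholders, and the registry decorators and dummy-data helpers are assumed to behave exactly as in the LLaVA hunk above.

# Sketch only: registration pattern introduced by this PR, applied to a
# hypothetical model.  `MyVLMForCausalLM` and `MyVLMConfig` (with its
# vision_config / image_token_index attributes) are placeholders; the
# registry calls and CLIP dummy-data helpers are the ones used in the
# LLaVA diff above.
import torch.nn as nn
from transformers import CLIPVisionConfig

from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY

from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision


def dummy_data_for_my_vlm(ctx: InputContext, seq_len: int):
    # The InputContext hands the hook the configs it needs to build
    # profiling inputs, instead of the hook reaching into global state.
    hf_config = ctx.get_hf_config(MyVLMConfig)  # hypothetical HF config class
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        # Dummy token sequence plus a matching dummy image payload,
        # mirroring dummy_data_for_llava above.
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
        )
        mm_data = dummy_pixel_data_for_clip(vision_config)
        return seq_data, mm_data

    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")


@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_vlm)
class MyVLMForCausalLM(nn.Module, SupportsVision):
    ...

The split shown here is the same one the LLaVA hunk performs: memory-profiling inputs (dummy data) are registered on the new input registry, while the multimodal registry retains only the per-modality input mappers.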