[Core] Registry for processing model inputs (#5214)
Co-authored-by: ywang96 <ywang@roblox.com>
@@ -2,10 +2,11 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
-from transformers import LlavaConfig
+from transformers import CLIPVisionConfig, LlavaConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.inputs import INPUT_REGISTRY, InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -16,10 +17,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import get_dummy_image_data
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
 from vllm.sequence import SamplerOutput
 
+from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
+                   dummy_seq_data_for_clip)
+from .interfaces import SupportsVision
 
 _KEYS_TO_MODIFY_MAPPING = {
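The dummy-data builders imported from .clip above are shared helpers that the model-specific factory in the next hunk composes. As a rough standalone sketch of how they are called (the concrete CLIP config and placeholder token id here are illustrative, not taken from this commit; the call shapes mirror the call sites in dummy_data_for_llava below):

from transformers import CLIPVisionConfig

from vllm.model_executor.models.clip import (dummy_pixel_data_for_clip,
                                             dummy_seq_data_for_clip)

# Illustrative vision config and placeholder token id; a real model reads
# these from its HF config (hf_config.vision_config, hf_config.image_token_index).
vision_config = CLIPVisionConfig()
image_token_id = 32000

# Dummy token ids padded with the image placeholder token, sized to the
# target sequence length used for memory profiling.
seq_data = dummy_seq_data_for_clip(vision_config, 2048,
                                   image_token_id=image_token_id)

# Dummy pixel values matching the vision tower's expected input resolution.
mm_data = dummy_pixel_data_for_clip(vision_config)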
@@ -83,9 +85,35 @@ class LlavaImageFeatureInputs(TypedDict):
 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
 
 
-@MULTIMODAL_REGISTRY.register_image_feature_input()
-@MULTIMODAL_REGISTRY.register_image_pixel_input()
-@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
+def dummy_data_for_llava(ctx: InputContext, seq_len: int):
+    multimodal_config = ctx.get_multimodal_config()
+    hf_config = ctx.get_hf_config(LlavaConfig)
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        seq_data = dummy_seq_data_for_clip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+        )
+
+        image_input_type = multimodal_config.image_input_type
+        ImageInputType = VisionLanguageConfig.ImageInputType
+        mm_data: MultiModalData
+        if image_input_type == ImageInputType.PIXEL_VALUES:
+            mm_data = dummy_pixel_data_for_clip(vision_config)
+        elif image_input_type == ImageInputType.IMAGE_FEATURES:
+            mm_data = dummy_feature_data_for_clip(vision_config)
+
+        return seq_data, mm_data
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
 class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
     def __init__(self,
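The net effect of this hunk is that dummy-data registration moves off MULTIMODAL_REGISTRY and onto the new per-model INPUT_REGISTRY hook. A minimal sketch of the same pattern for another vision-language model, assuming only the registry API shown above (the model class, its name, and the factory body are hypothetical placeholders, not part of this change):

import torch.nn as nn

from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY


def dummy_data_for_my_model(ctx: InputContext, seq_len: int):
    # A real factory would pull the model's HF and multi-modal configs
    # from `ctx` (as dummy_data_for_llava does above) and return a
    # (seq_data, mm_data) pair sized for memory profiling.
    raise NotImplementedError("illustrative placeholder")


@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_model)
class MyVLMForConditionalGeneration(nn.Module, SupportsVision):
    ...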