[VLM] Calculate maximum number of multi-modal tokens by model (#6121)
This commit is contained in:
@@ -35,6 +35,10 @@ def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
|
||||
patch_size=hf_config.patch_size)
|
||||
|
||||
|
||||
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    """Return the maximum number of image tokens for a CLIP vision config.

    This delegates to ``get_clip_image_feature_size`` and uses the value it
    computes for ``hf_config`` directly as the maximum.
    """
    max_image_tokens = get_clip_image_feature_size(hf_config)
    return max_image_tokens
|
||||
|
||||
|
||||
def dummy_seq_data_for_clip(
|
||||
hf_config: CLIPVisionConfig,
|
||||
seq_len: int,
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
input_processor_for_clip)
|
||||
get_max_clip_image_tokens, input_processor_for_clip)
|
||||
from .interfaces import SupportsVision
|
||||
from .utils import merge_vision_embeddings
|
||||
|
||||
@@ -62,6 +62,17 @@ class LlavaImagePixelInputs(TypedDict):
|
||||
LlavaImageInputs = LlavaImagePixelInputs
|
||||
|
||||
|
||||
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for a LLaVA model.

    The vision sub-config is read from the ``LlavaConfig`` obtained through
    ``ctx``. Only ``CLIPVisionConfig`` is handled here.

    Raises:
        NotImplementedError: If the vision config is not a
            ``CLIPVisionConfig``.
    """
    vision_cfg = ctx.get_hf_config(LlavaConfig).vision_config

    # Guard clause: reject any vision backbone this function does not handle.
    if not isinstance(vision_cfg, CLIPVisionConfig):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_cfg)}")

    return get_max_clip_image_tokens(vision_cfg)
|
||||
|
||||
|
||||
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
@@ -102,6 +113,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
@@ -127,6 +127,17 @@ def get_llava_next_image_feature_size(
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def get_max_llava_next_image_tokens(ctx: InputContext):
    """Return the image feature size for a fixed 448x448 dummy input.

    The ``LlavaNextConfig`` is fetched through ``ctx`` and passed, together
    with the dummy dimensions, to ``get_llava_next_image_feature_size``.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)

    # Per the original author's note, a square 448px input is meant to result
    # in the max possible feature size (2x2 grid of 336x336px tiles).
    # NOTE(review): that mapping depends on the feature-size helper, which is
    # not visible here -- confirm.
    dummy_side = 448

    return get_llava_next_image_feature_size(
        hf_config,
        input_height=dummy_side,
        input_width=dummy_side,
    )
|
||||
|
||||
|
||||
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
|
||||
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
@@ -198,6 +209,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
|
||||
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
@@ -321,6 +321,17 @@ def get_phi3v_image_feature_size(
|
||||
+ (new_height // 336 + 1) * 12
|
||||
|
||||
|
||||
def get_max_phi3v_image_tokens(ctx: InputContext):
    """Return the image feature size for a fixed 8000x50 dummy input.

    The ``PretrainedConfig`` is fetched through ``ctx`` and passed, together
    with the dummy dimensions, to ``get_phi3v_image_feature_size``.
    """
    hf_config = ctx.get_hf_config(PretrainedConfig)

    # Per the original author's note, these dimensions are meant to result in
    # the max possible feature size (h:w = 16:1).
    # NOTE(review): 8000x50 is literally a 160:1 ratio; presumably the
    # feature-size helper caps the aspect ratio -- confirm.
    dummy_size = {"input_height": 8000, "input_width": 50}

    return get_phi3v_image_feature_size(hf_config, **dummy_size)
|
||||
|
||||
|
||||
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
|
||||
# Result in the max possible feature size (h:w = 16:1)
|
||||
dummy_height, dummy_width = 8000, 50
|
||||
@@ -429,6 +440,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
|
||||
class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
|
||||
Reference in New Issue
Block a user