[VLM] Calculate maximum number of multi-modal tokens by model (#6121)

This commit is contained in:
Cyrus Leung
2024-07-05 07:37:23 +08:00
committed by GitHub
parent 69ec3ca14c
commit ae96ef8fbd
12 changed files with 265 additions and 95 deletions

View File

@@ -127,6 +127,17 @@ def get_llava_next_image_feature_size(
raise NotImplementedError(msg)
def get_max_llava_next_image_tokens(ctx: InputContext):
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
dummy_height = dummy_width = 448
return get_llava_next_image_feature_size(
ctx.get_hf_config(LlavaNextConfig),
input_height=dummy_height,
input_width=dummy_width,
)
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
@@ -198,6 +209,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):