[Model] Update multi-modal processor to support Mantis(LLaVA) model (#10711)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-08 01:10:05 +08:00
committed by GitHub
parent 1c768fe537
commit 39e227c7ae
14 changed files with 175 additions and 78 deletions

View File

@@ -22,10 +22,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.processing import (InputProcessingContext,
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
MultiModalProcessor, PromptReplacement)
PromptReplacement)
from vllm.sequence import IntermediateTensors
from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -163,7 +164,13 @@ def create_metadata_for_llava(
}
class LlavaProcessor(MultiModalProcessor):
class LlavaProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
)
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False):
@@ -193,7 +200,30 @@ class LlavaProcessor(MultiModalProcessor):
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
hf_processor = self._get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'],
return_tensors="pt")
is_pixtral = isinstance(hf_processor, PixtralProcessor)
return MultiModalKwargs(
**hf_inputs,
is_pixtral=torch.tensor(is_pixtral),
)
class LlavaLikeConfig(Protocol):
@@ -277,10 +307,7 @@ def init_vision_tower_for_llava(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
))
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
@@ -559,3 +586,28 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
class MantisProcessor(LlavaProcessor):
def _get_hf_processor(self) -> ProcessorMixin:
try:
from mantis.models.mllava import MLlavaProcessor
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/TIGER-AI-Lab/Mantis.git` "
"to use this model") from exc
processor = MLlavaProcessor.from_pretrained(
self.ctx.model_config.tokenizer)
assert isinstance(processor, ProcessorMixin)
return processor
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass