[Model] Update multi-modal processor to support Mantis(LLaVA) model (#10711)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -22,10 +22,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.processing import (InputProcessingContext,
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
InputProcessingContext,
|
||||
ModalityProcessingMetadata,
|
||||
MultiModalProcessingMetadata,
|
||||
MultiModalProcessor, PromptReplacement)
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||
@@ -163,7 +164,13 @@ def create_metadata_for_llava(
|
||||
}
|
||||
|
||||
|
||||
class LlavaProcessor(MultiModalProcessor):
|
||||
class LlavaProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def __init__(self, ctx: InputProcessingContext) -> None:
|
||||
super().__init__(
|
||||
ctx=ctx,
|
||||
metadata=create_metadata_for_llava(ctx),
|
||||
)
|
||||
|
||||
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
|
||||
if getattr(hf_processor, "__is_patched__", False):
|
||||
@@ -193,7 +200,30 @@ class LlavaProcessor(MultiModalProcessor):
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalKwargs:
|
||||
return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
data = dummy_image_for_clip(vision_config, num_images)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
data = dummy_image_for_siglip(vision_config, num_images)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
data = dummy_image_for_pixtral_hf(vision_config, num_images)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor = hf_processor.image_processor # type: ignore
|
||||
hf_inputs = image_processor.preprocess(data['image'],
|
||||
return_tensors="pt")
|
||||
is_pixtral = isinstance(hf_processor, PixtralProcessor)
|
||||
|
||||
return MultiModalKwargs(
|
||||
**hf_inputs,
|
||||
is_pixtral=torch.tensor(is_pixtral),
|
||||
)
|
||||
|
||||
|
||||
class LlavaLikeConfig(Protocol):
|
||||
@@ -277,10 +307,7 @@ def init_vision_tower_for_llava(
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
|
||||
ctx=ctx,
|
||||
metadata=create_metadata_for_llava(ctx),
|
||||
))
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
@@ -559,3 +586,28 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class MantisProcessor(LlavaProcessor):
|
||||
|
||||
def _get_hf_processor(self) -> ProcessorMixin:
|
||||
try:
|
||||
from mantis.models.mllava import MLlavaProcessor
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
"You need to `pip install "
|
||||
"git+https://github.com/TIGER-AI-Lab/Mantis.git` "
|
||||
"to use this model") from exc
|
||||
|
||||
processor = MLlavaProcessor.from_pretrained(
|
||||
self.ctx.model_config.tokenizer)
|
||||
assert isinstance(processor, ProcessorMixin)
|
||||
return processor
|
||||
|
||||
|
||||
# To use this model, please use
|
||||
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
|
||||
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user