[Model] Add base class for LoRA-supported models (#5018)
@@ -25,8 +25,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
 from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
 from vllm.sequence import SamplerOutput, SequenceData
 
+from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector, merge_vision_embeddings
-from .vlm_base import VisionLanguageModelBase
 
 logger = init_logger(__name__)
@@ -106,19 +106,21 @@ def _image_pixel_processor(
 @MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
 @MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
-class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
+class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
+
+    supports_vision = True
 
     def __init__(self,
                  config: LlavaNextConfig,
-                 vision_language_config: VisionLanguageConfig,
+                 vlm_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
-        super().__init__(vision_language_config)
+        super().__init__()
 
-        # Update the type annotation from that of its superclass
         self.config = config
+        self.vlm_config = vlm_config
 
-        if self.vision_language_config.image_input_type == (
+        if self.vlm_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config=config.vision_config)
         else:
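The hunk above replaces the VisionLanguageModelBase base class with an nn.Module + SupportsVision mix-in and a class-level supports_vision flag. Below is a rough sketch of how such a flag-based interface can be expressed; the real definition lives in the .interfaces module imported above and may differ, and model_supports_vision is an illustrative helper, not a vLLM API.

from typing import ClassVar, Literal, Protocol, runtime_checkable


@runtime_checkable
class SupportsVision(Protocol):
    """Marker interface: models that consume image inputs expose a class flag."""

    supports_vision: ClassVar[Literal[True]] = True


def model_supports_vision(model_cls: type) -> bool:
    """Structural check on the class itself, usable before instantiation."""
    return getattr(model_cls, "supports_vision", False) is True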
@@ -146,7 +148,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
             torch.empty(config.text_config.hidden_size))
 
     def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
-        _, num_channels, _, _ = self.vision_language_config.image_input_shape
+        _, num_channels, _, _ = self.vlm_config.image_input_shape
 
         # Note that this is different from that of vLLM vision_language_config
         # since the image is resized by the HuggingFace preprocessor
@@ -177,7 +179,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
         image_sizes = kwargs.pop("image_sizes", None)
         image_features = kwargs.pop("image_features", None)
 
-        expected_input_type = self.vision_language_config.image_input_type
+        expected_input_type = self.vlm_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType
 
         if expected_input_type == ImageInputType.PIXEL_VALUES:
@@ -386,7 +388,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
 
             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vision_language_config.image_token_id)
+                self.vlm_config.image_token_id)
 
             input_ids = None
         else:
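Usage-level sketch of the constructor change applied to LlavaNextForConditionalGeneration: with no config-holding base class, the model calls super().__init__() with no arguments and keeps the config on self.vlm_config. Toy class names and the plain dict standing in for VisionLanguageConfig are illustrative assumptions, not vLLM code.

import torch.nn as nn


class SupportsVisionDemo:
    # Stand-in for the SupportsVision interface: only the class-level flag matters here.
    supports_vision = True


class ToyVisionLM(nn.Module, SupportsVisionDemo):

    def __init__(self, vlm_config: dict) -> None:
        super().__init__()            # plain nn.Module init, no config forwarded
        self.vlm_config = vlm_config  # config stored on the model itself

    def image_token_id(self) -> int:
        return self.vlm_config["image_token_id"]


model = ToyVisionLM({"image_token_id": 32000})
print(model.supports_vision, model.image_token_id())  # True 32000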