[Model] Add base class for LoRA-supported models (#5018)

Cyrus Leung
2024-06-27 16:03:04 +08:00
committed by GitHub
parent d12af207d2
commit 96354d6a29
20 changed files with 270 additions and 75 deletions
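
The excerpt below shows the LLaVA model being migrated from the concrete VisionLanguageModelBase base class to the interface style introduced by this commit: the class now inherits from plain nn.Module plus a SupportsVision marker imported from .interfaces, advertises the capability through a supports_vision = True class attribute, and keeps its VisionLanguageConfig as self.vlm_config. The body of .interfaces is not part of this excerpt, so the following is only a minimal sketch assuming a typing.Protocol-based marker; apart from SupportsVision, supports_vision, and VisionLanguageConfig, the names here are illustrative.

# Minimal sketch of a capability interface like the SupportsVision imported in
# the diff below. The real interfaces module is not shown in this excerpt, so
# this Protocol body is an assumption, not the actual vLLM implementation.
from typing import ClassVar, Literal, Protocol, runtime_checkable


@runtime_checkable
class SupportsVision(Protocol):
    """Marker for models that accept image inputs alongside text."""

    # Model classes set this flag explicitly, mirroring the
    # `supports_vision = True` line added to LlavaForConditionalGeneration.
    supports_vision: ClassVar[Literal[True]] = True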


@@ -20,7 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import get_dummy_image_data
 from vllm.sequence import SamplerOutput
 
-from .vlm_base import VisionLanguageModelBase
+from .interfaces import SupportsVision
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -86,18 +86,21 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
 @MULTIMODAL_REGISTRY.register_image_feature_input()
 @MULTIMODAL_REGISTRY.register_image_pixel_input()
 @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
-class LlavaForConditionalGeneration(VisionLanguageModelBase):
+class LlavaForConditionalGeneration(nn.Module, SupportsVision):
+
+    supports_vision = True
 
     def __init__(self,
                  config: LlavaConfig,
-                 vision_language_config: VisionLanguageConfig,
+                 vlm_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
-        super().__init__(vision_language_config)
+        super().__init__()
 
         self.config = config
+        self.vlm_config = vlm_config
 
-        if self.vision_language_config.image_input_type == (
+        if self.vlm_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config.vision_config)
         else:
@@ -122,11 +125,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
         self.sampler = Sampler()
 
     def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != list(
-                self.vision_language_config.image_input_shape[1:]):
+        if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
             raise ValueError(
                 f"The expected image tensor shape is batch dimension plus "
-                f"{self.vision_language_config.image_input_shape[1:]}. "
+                f"{self.vlm_config.image_input_shape[1:]}. "
                 f"You supplied {data.shape}. "
                 f"If you are using vLLM's entrypoint, make sure your "
                 f"supplied image input is consistent with "
@@ -139,7 +141,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
         pixel_values = kwargs.pop("pixel_values", None)
         image_features = kwargs.pop("image_features", None)
 
-        expected_input_type = self.vision_language_config.image_input_type
+        expected_input_type = self.vlm_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType
 
         if expected_input_type == ImageInputType.PIXEL_VALUES:
@@ -273,7 +275,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vision_language_config.image_token_id)
+                self.vlm_config.image_token_id)
 
             input_ids = None
         else:
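
With a class-level flag in place of a shared concrete base class, downstream code can test for vision support generically instead of checking isinstance(model, VisionLanguageModelBase). A possible check, assuming nothing beyond the flag added in this diff (the helper name and signature are illustrative, not part of the commit):

def supports_vision(model: object) -> bool:
    # Illustrative helper, not part of this commit: report whether the model's
    # class declares the supports_vision capability flag.
    return getattr(type(model), "supports_vision", False) is True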