[Model] Add base class for LoRA-supported models (#5018)
This commit is contained in:
@@ -20,7 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import get_dummy_image_data
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
from .vlm_base import VisionLanguageModelBase
|
||||
from .interfaces import SupportsVision
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"language_model.lm_head": "lm_head",
|
||||
@@ -86,18 +86,21 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
|
||||
@MULTIMODAL_REGISTRY.register_image_feature_input()
|
||||
@MULTIMODAL_REGISTRY.register_image_pixel_input()
|
||||
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
|
||||
class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
supports_vision = True
|
||||
|
||||
def __init__(self,
|
||||
config: LlavaConfig,
|
||||
vision_language_config: VisionLanguageConfig,
|
||||
vlm_config: VisionLanguageConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
super().__init__(vision_language_config)
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.vlm_config = vlm_config
|
||||
|
||||
if self.vision_language_config.image_input_type == (
|
||||
if self.vlm_config.image_input_type == (
|
||||
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
|
||||
self.vision_tower = CLIPVisionModel(config.vision_config)
|
||||
else:
|
||||
@@ -122,11 +125,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
self.sampler = Sampler()
|
||||
|
||||
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if list(data.shape[1:]) != list(
|
||||
self.vision_language_config.image_input_shape[1:]):
|
||||
if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
|
||||
raise ValueError(
|
||||
f"The expected image tensor shape is batch dimension plus "
|
||||
f"{self.vision_language_config.image_input_shape[1:]}. "
|
||||
f"{self.vlm_config.image_input_shape[1:]}. "
|
||||
f"You supplied {data.shape}. "
|
||||
f"If you are using vLLM's entrypoint, make sure your "
|
||||
f"supplied image input is consistent with "
|
||||
@@ -139,7 +141,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_features = kwargs.pop("image_features", None)
|
||||
|
||||
expected_input_type = self.vision_language_config.image_input_type
|
||||
expected_input_type = self.vlm_config.image_input_type
|
||||
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||
|
||||
if expected_input_type == ImageInputType.PIXEL_VALUES:
|
||||
@@ -273,7 +275,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
|
||||
inputs_embeds = merge_vision_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.vision_language_config.image_token_id)
|
||||
self.vlm_config.image_token_id)
|
||||
|
||||
input_ids = None
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user