[Model][Bugfix] Implicit model flags and reenable Phi-3-Vision (#5896)

This commit is contained in:
Cyrus Leung
2024-06-28 00:08:10 +08:00
committed by GitHub
parent e9d32d077d
commit 98cf2ed678
14 changed files with 26 additions and 32 deletions

View File

@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput
from .interfaces import SupportsVision
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
@@ -317,18 +318,21 @@ def _image_processor(
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class Phi3VForCausalLM(VisionLanguageModelBase):
class Phi3VForCausalLM(nn.Module, SupportsVision):
def __init__(self,
config: PretrainedConfig,
vision_language_config: VisionLanguageConfig,
vlm_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__(vision_language_config)
super().__init__()
self.config = config
self.vlm_config = vlm_config
self.model = LlamaModel(config, cache_config, quant_config)
self.vision_embed_tokens = Phi3HDImageEmbedding(
vision_language_config, config, self.model.embed_tokens)
vlm_config, config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@@ -338,7 +342,7 @@ class Phi3VForCausalLM(VisionLanguageModelBase):
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
expected_input_type = self.vision_language_config.image_input_type
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type != ImageInputType.PIXEL_VALUES: