[Feature] Add vision language model support. (#3042)

This commit is contained in:
xwjiang2010
2024-03-25 14:16:30 -07:00
committed by GitHub
parent f408d05c52
commit 64172a976c
28 changed files with 936 additions and 94 deletions

View File

@@ -6,7 +6,7 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
@@ -40,6 +40,7 @@ class RayGPUExecutor(ExecutorBase):
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
@@ -47,6 +48,7 @@ class RayGPUExecutor(ExecutorBase):
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
assert self.parallel_config.worker_use_ray
placement_group = self.parallel_config.placement_group
@@ -181,6 +183,7 @@ class RayGPUExecutor(ExecutorBase):
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=True,
)