[Feature] Add vision language model support. (#3042)

Author: xwjiang2010 (committed by GitHub)
Date: 2024-03-25 14:16:30 -07:00
Commit: 64172a976c (parent: f408d05c52)
28 changed files with 936 additions and 94 deletions

File: vllm/engine/async_llm_engine.py

@@ -15,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import MultiModalData
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = int(
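The new import pulls in `MultiModalData`, the container this PR introduces in `vllm/sequence.py` for passing non-text inputs alongside a prompt. A minimal construction sketch, assuming the type-tag-plus-tensor shape the PR gives it; the image shape is illustrative:

```python
import torch

from vllm.sequence import MultiModalData

# Illustrative pixel values for one image; real inputs come from an
# image processor. The 1x3x336x336 shape mirrors the LLaVA-1.5 example
# in this PR, but is an assumption here.
pixel_values = torch.zeros(1, 3, 336, 336)
mm_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
```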
@@ -240,6 +241,7 @@ class _AsyncLLMEngine(LLMEngine):
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -252,14 +254,13 @@ class _AsyncLLMEngine(LLMEngine):
                 prompt_token_ids=prompt_token_ids,
                 lora_request=lora_request)
-        return self.add_request(
-            request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            sampling_params=sampling_params,
-            arrival_time=arrival_time,
-            lora_request=lora_request,
-        )
+        return self.add_request(request_id,
+                                prompt=prompt,
+                                prompt_token_ids=prompt_token_ids,
+                                sampling_params=sampling_params,
+                                arrival_time=arrival_time,
+                                lora_request=lora_request,
+                                multi_modal_data=multi_modal_data)
 
     async def check_health_async(self) -> None:
         self.model_executor.check_health()
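The reshaped `return` above simply threads the new keyword through to the synchronous `LLMEngine.add_request`, which this PR extends the same way. A hedged sketch of calling that method directly; the request id and sampling parameters are illustrative:

```python
from vllm import SamplingParams
from vllm.engine.llm_engine import LLMEngine
from vllm.sequence import MultiModalData


def submit_vision_request(engine: LLMEngine, prompt: str,
                          mm_data: MultiModalData) -> None:
    # multi_modal_data rides along with the usual arguments; the engine
    # attaches it to the sequence group it creates for the request.
    engine.add_request(request_id="vision-0",
                       prompt=prompt,
                       sampling_params=SamplingParams(max_tokens=64),
                       multi_modal_data=mm_data)
```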
@@ -486,6 +487,7 @@ class AsyncLLMEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncStream:
         if self.log_requests:
             shortened_prompt = prompt
@@ -534,7 +536,9 @@ class AsyncLLMEngine:
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
-            lora_request=lora_request)
+            lora_request=lora_request,
+            multi_modal_data=multi_modal_data,
+        )
 
         return stream
@@ -545,6 +549,7 @@ class AsyncLLMEngine:
         request_id: str,
         prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
@@ -560,6 +565,7 @@ class AsyncLLMEngine:
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi-modal data per request, if any.
 
         Yields:
             The output `RequestOutput` objects from the LLMEngine for the
@@ -619,6 +625,7 @@ class AsyncLLMEngine:
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time,
                 lora_request=lora_request,
+                multi_modal_data=multi_modal_data,
             )
 
             async for request_output in stream:
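Taken together, a hedged end-to-end sketch of the async path this diff enables. The model name, `image_*` engine arguments, and prompt format follow the LLaVA example added elsewhere in this PR; treat the exact values as assumptions:

```python
import asyncio

import torch

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sequence import MultiModalData


async def main() -> None:
    # The image_* knobs are introduced by this PR; values mirror its
    # LLaVA-1.5 example (576 image feature tokens, 336x336 pixels).
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="llava-hf/llava-1.5-7b-hf",
                        image_input_type="pixel_values",
                        image_token_id=32000,
                        image_input_shape="1,3,336,336",
                        image_feature_size=576))

    # In this early API the prompt carries one placeholder token per
    # image feature, as in the PR's example script.
    prompt = "<image>" * 576 + "\nUSER: What is in this image?\nASSISTANT:"
    mm_data = MultiModalData(type=MultiModalData.Type.IMAGE,
                             data=torch.zeros(1, 3, 336, 336))

    final = None
    async for request_output in engine.generate(
            prompt,
            SamplingParams(max_tokens=64),
            request_id="vision-0",
            multi_modal_data=mm_data):
        final = request_output
    if final is not None:
        print(final.outputs[0].text)


asyncio.run(main())
```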