[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
This commit is contained in:
@@ -20,9 +20,12 @@ from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
RequestOutputFactory)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
MultiModalData, PoolerOutput, SamplerOutput,
|
||||
Sequence, SequenceGroup, SequenceGroupMetadata,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
@@ -169,7 +172,8 @@ class LLMEngine:
|
||||
load_config=load_config,
|
||||
)
|
||||
|
||||
self._initialize_kv_caches()
|
||||
if not self.model_config.embedding_mode:
|
||||
self._initialize_kv_caches()
|
||||
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
if is_usage_stats_enabled():
|
||||
@@ -354,7 +358,7 @@ class LLMEngine:
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
@@ -370,7 +374,8 @@ class LLMEngine:
|
||||
request_id: The unique ID of the request.
|
||||
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||
provided.
|
||||
sampling_params: The sampling parameters for text generation.
|
||||
params: Parameters for sampling or pooling. SamplingParams
|
||||
for text generation. PoolingParams for pooling.
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
arrival_time: The arrival time of the request. If None, we use
|
||||
@@ -404,13 +409,6 @@ class LLMEngine:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
max_logprobs = self.get_model_config().max_logprobs
|
||||
if (sampling_params.logprobs
|
||||
and sampling_params.logprobs > max_logprobs) or (
|
||||
sampling_params.prompt_logprobs
|
||||
and sampling_params.prompt_logprobs > max_logprobs):
|
||||
raise ValueError(f"Cannot request more than "
|
||||
f"{max_logprobs} logprobs.")
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
prompt_token_ids = self.encode_request(
|
||||
@@ -432,6 +430,50 @@ class LLMEngine:
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
||||
eos_token_id, lora_request)
|
||||
|
||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||
if isinstance(params, SamplingParams):
|
||||
seq_group = self._create_sequence_group_with_sampling(
|
||||
request_id,
|
||||
seq,
|
||||
params,
|
||||
arrival_time,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
)
|
||||
elif isinstance(params, PoolingParams):
|
||||
seq_group = self._create_sequence_group_with_pooling(
|
||||
request_id,
|
||||
seq,
|
||||
params,
|
||||
arrival_time,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either SamplingParams or PoolingParams must be provided.")
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
|
||||
def _create_sequence_group_with_sampling(
|
||||
self,
|
||||
request_id: str,
|
||||
seq: Sequence,
|
||||
sampling_params: SamplingParams,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> SequenceGroup:
|
||||
"""Creates a SequenceGroup with SamplingParams."""
|
||||
max_logprobs = self.get_model_config().max_logprobs
|
||||
if (sampling_params.logprobs
|
||||
and sampling_params.logprobs > max_logprobs) or (
|
||||
sampling_params.prompt_logprobs
|
||||
and sampling_params.prompt_logprobs > max_logprobs):
|
||||
raise ValueError(f"Cannot request more than "
|
||||
f"{max_logprobs} logprobs.")
|
||||
|
||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||
# this doesn't deep-copy LogitsProcessor objects
|
||||
sampling_params = sampling_params.clone()
|
||||
@@ -443,11 +485,35 @@ class LLMEngine:
|
||||
self.generation_config_fields)
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
arrival_time, lora_request, multi_modal_data)
|
||||
seq_group = SequenceGroup(request_id=request_id,
|
||||
seqs=[seq],
|
||||
arrival_time=arrival_time,
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
return seq_group
|
||||
|
||||
def _create_sequence_group_with_pooling(
    self,
    request_id: str,
    seq: Sequence,
    pooling_params: PoolingParams,
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup:
    """Build the SequenceGroup for a request that carries PoolingParams.

    Args:
        request_id: ID forwarded to the new group.
        seq: The single sequence the group will contain.
        pooling_params: Pooling parameters; a clone is stored on the
            group, not the object passed in.
        arrival_time: Forwarded unchanged to the group.
        lora_request: Forwarded unchanged to the group.
        multi_modal_data: Forwarded unchanged to the group.

    Returns:
        The newly constructed SequenceGroup.
    """
    # Clone first so the group owns its own PoolingParams instance
    # (the pooler consumes these) instead of sharing the caller's.
    group_kwargs = dict(
        request_id=request_id,
        seqs=[seq],
        arrival_time=arrival_time,
        lora_request=lora_request,
        multi_modal_data=multi_modal_data,
        pooling_params=pooling_params.clone(),
    )
    return SequenceGroup(**group_kwargs)
|
||||
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
"""Aborts a request(s) with the given ID.
|
||||
@@ -484,13 +550,25 @@ class LLMEngine:
|
||||
"""Returns True if there are unfinished requests."""
|
||||
return self.scheduler.has_unfinished_seqs()
|
||||
|
||||
def _process_sequence_group_outputs(
    self,
    seq_group: SequenceGroup,
    outputs: List[EmbeddingSequenceGroupOutput],
) -> None:
    """Store the embedding result on the group and finish its sequences.

    Args:
        seq_group: Group that receives the embeddings; each of its
            sequences is marked ``FINISHED_STOPPED``.
        outputs: Embedding outputs for the group. Only the first
            element is read.
    """
    # Only the first output entry is consulted for the embeddings.
    first_output = outputs[0]
    seq_group.embeddings = first_output.embeddings

    # Mark every sequence in the group as finished.
    for member in seq_group.get_seqs():
        member.status = SequenceStatus.FINISHED_STOPPED
|
||||
|
||||
def _process_model_outputs(
|
||||
self,
|
||||
output: List[SamplerOutput],
|
||||
output: List[Union[SamplerOutput, PoolerOutput]],
|
||||
scheduled_seq_groups: List[ScheduledSequenceGroup],
|
||||
ignored_seq_groups: List[SequenceGroup],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> List[RequestOutput]:
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||
|
||||
Returns RequestOutputs that can be returned to the client.
|
||||
@@ -510,6 +588,9 @@ class LLMEngine:
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.update_num_computed_tokens(
|
||||
scheduled_seq_group.token_chunk_size)
|
||||
if self.model_config.embedding_mode:
|
||||
self._process_sequence_group_outputs(seq_group, outputs)
|
||||
continue
|
||||
|
||||
self.output_processor.process_prompt_logprob(seq_group, outputs)
|
||||
if seq_group_meta.do_sample:
|
||||
@@ -519,18 +600,19 @@ class LLMEngine:
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
request_outputs: List[Union[RequestOutput,
|
||||
EmbeddingRequestOutput]] = []
|
||||
for scheduled_seq_group in scheduled_seq_groups:
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.maybe_set_first_token_time(now)
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_output = RequestOutputFactory.create(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
for seq_group in ignored_seq_groups:
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_output = RequestOutputFactory.create(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
return request_outputs
|
||||
|
||||
def step(self) -> List[RequestOutput]:
|
||||
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
|
||||
.. figure:: https://i.imgur.com/sv2HssD.png
|
||||
@@ -570,7 +652,7 @@ class LLMEngine:
|
||||
>>> while True:
|
||||
>>> if example_inputs:
|
||||
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
|
||||
>>> engine.add_request(str(req_id), prompt, sampling_params)
|
||||
>>> engine.add_request(str(req_id),prompt,sampling_params)
|
||||
>>>
|
||||
>>> # continue the request processing
|
||||
>>> request_outputs = engine.step()
|
||||
@@ -637,12 +719,15 @@ class LLMEngine:
|
||||
|
||||
# KV Cache Usage in %
|
||||
num_total_gpu = self.cache_config.num_gpu_blocks
|
||||
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
|
||||
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
|
||||
gpu_cache_usage_sys = 0.
|
||||
if num_total_gpu is not None:
|
||||
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
|
||||
)
|
||||
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
|
||||
|
||||
num_total_cpu = self.cache_config.num_cpu_blocks
|
||||
cpu_cache_usage_sys = 0.
|
||||
if num_total_cpu > 0:
|
||||
if num_total_cpu is not None and num_total_cpu > 0:
|
||||
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
|
||||
)
|
||||
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
|
||||
@@ -716,8 +801,10 @@ class LLMEngine:
|
||||
seq.get_output_len()
|
||||
for seq in seq_group.get_finished_seqs()
|
||||
])
|
||||
best_of_requests.append(seq_group.sampling_params.best_of)
|
||||
n_requests.append(seq_group.sampling_params.n)
|
||||
if seq_group.sampling_params is not None:
|
||||
best_of_requests.append(
|
||||
seq_group.sampling_params.best_of)
|
||||
n_requests.append(seq_group.sampling_params.n)
|
||||
finished_reason_requests.extend([
|
||||
SequenceStatus.get_finished_reason(seq.status)
|
||||
for seq in seq_group.get_finished_seqs()
|
||||
|
||||
Reference in New Issue
Block a user