[Frontend] Separate pooling APIs in offline inference (#11129)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-13 18:40:07 +08:00
committed by GitHub
parent f93bf2b189
commit eeec9e3390
21 changed files with 669 additions and 304 deletions

View File

@@ -46,11 +46,10 @@ from vllm.outputs import (PoolingRequestOutput, RequestOutput,
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           ParallelSampleSequenceGroup, Sequence,
-                           SequenceGroup, SequenceGroupBase,
-                           SequenceGroupMetadata, SequenceGroupOutput,
-                           SequenceStatus)
+from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
+                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
+                           SequenceGroupBase, SequenceGroupMetadata,
+                           SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@@ -966,9 +965,9 @@ class LLMEngine:
@staticmethod
def _process_sequence_group_outputs(
seq_group: SequenceGroup,
-        outputs: List[EmbeddingSequenceGroupOutput],
+        outputs: List[PoolingSequenceGroupOutput],
) -> None:
-        seq_group.embeddings = outputs[0].embeddings
+        seq_group.pooled_data = outputs[0].data
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
@@ -1784,8 +1783,8 @@ class LLMEngine:
num_prompt_tokens_iter)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
-        if model_output and (model_output[0].spec_decode_worker_metrics
-                             is not None):
+        if model_output and isinstance(model_output[0], SamplerOutput) and (
+                model_output[0].spec_decode_worker_metrics is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None