Files
vllm/vllm/entrypoints/pooling/typing.py
2026-03-16 19:55:53 -04:00

87 lines
2.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any, Generic, TypeAlias, TypeVar
from fastapi import Request
from pydantic import ConfigDict
from vllm import PoolingRequestOutput
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedRequest,
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
from vllm.inputs import ProcessorInputs
from vllm.lora.request import LoRARequest
# Any "completion"-flavoured request model, drawn from the embed, classify
# and pooling protocol modules. Member order is kept as declared.
PoolingCompletionLikeRequest: TypeAlias = (
    EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | PoolingCompletionRequest
)
# Any "chat"-flavoured request model, drawn from the embed, classify and
# pooling protocol modules. Member order is kept as declared.
PoolingChatLikeRequest: TypeAlias = (
    EmbeddingChatRequest
    | ClassificationChatRequest
    | PoolingChatRequest
)
# Umbrella union of every request model named in this module: the
# completion-like and chat-like groups plus the IO-processor, rerank, score
# and Cohere-embed requests. It serves as the upper bound of PoolingRequestT.
AnyPoolingRequest: TypeAlias = (
    PoolingCompletionLikeRequest
    | PoolingChatLikeRequest
    | IOProcessorRequest
    | RerankRequest
    | ScoreRequest
    | CohereEmbedRequest
)
# Umbrella union of the response models imported from the classify, embed,
# pooling and score protocol modules.
AnyPoolingResponse: TypeAlias = (
    ClassificationResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | PoolingResponse
    | ScoreResponse
)
# Type variable for the concrete request model carried by a
# PoolingServeContext; constrained to the members of AnyPoolingRequest.
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
    """Keyword-only bundle of state attached to a single pooling request.

    The class is generic over the request model (``PoolingRequestT``).
    ``request``, ``model_name`` and ``request_id`` have no default and must
    be supplied; every other field starts out as ``None``, an empty list, or
    the construction timestamp.
    """

    # The request model itself, plus the optional FastAPI request object.
    request: PoolingRequestT
    raw_request: Request | None = None

    # Required identifiers.
    model_name: str
    request_id: str

    # Whole-second wall-clock time taken when the instance is constructed.
    created_time: int = field(default_factory=lambda: int(time.time()))

    lora_request: LoRARequest | None = None

    # Optional working state; all unset (None) at construction.
    # NOTE(review): presumably filled in by later serving stages - confirm
    # against the code that consumes this context.
    engine_prompts: list[ProcessorInputs] | None = None
    prompt_request_ids: list[str] | None = None
    intermediates: Any | None = None

    # Async stream yielding (int, PoolingRequestOutput) pairs, if any.
    result_generator: (
        AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None
    ) = None

    # Each instance gets its own fresh list.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # Un-annotated, so this is a plain class attribute and not a dataclass
    # field. NOTE(review): the stdlib dataclass decorator does not read it;
    # presumably retained for pydantic-related tooling - confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)