# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import json
import time
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar

import numpy as np
from fastapi import Request
from openai.types.responses import ToolChoiceFunction
from pydantic import ConfigDict, TypeAdapter, ValidationError
from starlette.datastructures import Headers

import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
    BatchChatCompletionRequest,
    ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from vllm.entrypoints.openai.completion.protocol import (
    CompletionRequest,
    CompletionResponse,
)
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
    FunctionCall,
    FunctionDefinition,
    GenerationError,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.openai.speech_to_text.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
from vllm.entrypoints.utils import create_error_response
from vllm.inputs import EngineInput, PromptType
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers import ChatParams, TokenizeParams
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
)
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import collect_from_async_generator

logger = init_logger(__name__)


class RendererRequest(Protocol):
    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        raise NotImplementedError


class RendererChatRequest(RendererRequest, Protocol):
    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        raise NotImplementedError


CompletionLikeRequest: TypeAlias = (
    CompletionRequest | TokenizeCompletionRequest | DetokenizeRequest
)

ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest | BatchChatCompletionRequest | TokenizeChatRequest
)

SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | GenerateRequest
)

AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | TranscriptionResponse
    | TokenizeResponse
    | GenerateResponse
)

RequestT = TypeVar("RequestT", bound=AnyRequest)


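# Per-request serving state: the parsed request, its identifiers, any resolved
# LoRA adapter, and the engine inputs produced during preprocessing.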
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    request: RequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    engine_inputs: list[EngineInput] | None = None
    model_config = ConfigDict(arbitrary_types_allowed=True)


class OpenAIServing:
    request_id_prefix: ClassVar[str]
    """
    A short string prepended to every request’s ID.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__()

        self.engine_client = engine_client
        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        self.model_config = engine_client.model_config
        self.renderer = engine_client.renderer
        self.input_processor = engine_client.input_processor

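    # Beam search is driven one decode step at a time: every live beam issues a
    # single-token generate() call with logprobs=2 * beam_width, the candidate
    # continuations are ranked by cumulative logprob, beams that hit EOS are
    # moved to `completed`, and the top `beam_width` survivors continue. After
    # `max_tokens` steps the best beams are decoded and yielded as one final
    # RequestOutput; an engine error aborts the search with an error output.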
    async def beam_search(
        self,
        prompt: EngineInput,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Store all new tokens generated by beam
            all_beams_token_id = []
            # Store the cumulative probability of all tokens
            # generated by beam search
            all_beams_logprob = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend(
                        [
                            current_beam.cum_logprob + obj.logprob
                            for obj in logprobs.values()
                        ]
                    )

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    current_beam = all_beams[idx // logprobs_num]
                    result = output[idx // logprobs_num]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities
            topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
                :beam_width
            ]

            for idx in topn_idx:
                current_beam = all_beams[idx // logprobs_num]
                result = output[idx // logprobs_num]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for i, beam in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )

    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        return create_error_response(message, err_type, status_code, param)

    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        json_str = json.dumps(
            self.create_error_response(
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
            ).model_dump()
        )
        return json_str

    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Convert GenerationError to streaming error response."""
        return self.create_streaming_error_response(
            str(e),
            err_type="InternalServerError",
            status_code=e.status_code,
        )

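    # Model resolution order: the served base model name, then an already
    # loaded LoRA adapter, then (when VLLM_ALLOW_RUNTIME_LORA_UPDATING is set)
    # an attempt to resolve the adapter at runtime; anything else yields a 404.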
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )

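    # Default multimodal LoRA selection: the request's content "type" fields
    # are truncated at the first "_" (e.g. "audio_url" -> "audio") and matched
    # against loaded adapter names; a default adapter is applied only when
    # exactly one modality matches.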
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # task types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

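    # Adapter precedence: an explicitly requested LoRA adapter wins, then a
    # matching default multimodal adapter, then the base model (no adapter).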
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None

    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
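        # e.g. defaults {"add_generation_prompt": True} merged with request
        # kwargs {"add_generation_prompt": False} give {"add_generation_prompt": False}
        # (illustrative values; the right-hand operand of `|` wins on conflicts).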
        return default_chat_template_kwargs | request_chat_template_kwargs

    def _extract_prompt_components(self, prompt: PromptType | EngineInput):
        return extract_prompt_components(self.model_config, prompt)

    def _extract_prompt_text(self, prompt: PromptType | EngineInput):
        return self._extract_prompt_components(prompt).text

    def _extract_prompt_len(self, prompt: EngineInput):
        return extract_prompt_len(self.model_config, prompt)

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | EngineInput,
        params: SamplingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        if self.request_logger is None:
            return

        components = self._extract_prompt_components(inputs)

        self.request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

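    # Clients may pin a request ID via the X-Request-Id header, e.g.
    # (illustrative): curl -H "X-Request-Id: my-trace-123" .../v1/completions
    # Without the header, a random UUID (or the provided default) is used.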
    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided"""
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        return random_uuid() if default is None else default

    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

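    # Tool-call extraction depends on `tool_choice`: a named/forced tool turns
    # the generated content into that tool's arguments, "required" expects the
    # content to be a JSON list of function definitions, and "auto"/None defers
    # to the configured ToolParser when automatic tool parsing is enabled.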
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: type[ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        function_calls = list[FunctionCall]()
        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice and isinstance(
            request.tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.function.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice == "required":
            tool_calls = []
            with contextlib.suppress(ValidationError):
                content = content or ""
                tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
                    content
                )
            for tool_call in tool_calls:
                function_calls.append(
                    FunctionCall(
                        name=tool_call.name,
                        arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                    )
                )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (request.tool_choice == "auto" or request.tool_choice is None)
        ):
            if tokenizer is None:
                raise ValueError(
                    "Tokenizer not available when `skip_tokenizer_init=True`"
                )

            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer, request.tools)
            except RuntimeError as e:
                logger.exception("Error in tool parser creation.")
                raise e
            tool_call_info = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if tool_call_info is not None and tool_call_info.tools_called:
                # extract_tool_calls() returns a list of tool calls.
                function_calls.extend(
                    FunctionCall(
                        id=tool_call.id,
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                    for tool_call in tool_call_info.tool_calls
                )
                content = tool_call_info.content
                if content and content.strip() == "":
                    content = None
        else:
            # No tool calls.
            return None, content

        return function_calls, content

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

        return tokenizer.decode([token_id])

    def _is_model_supported(self, model_name: str | None) -> bool:
        if not model_name:
            return True
        return self.models.is_base_model(model_name)


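# Clamp -inf prompt logprobs to a finite sentinel (-9999.0): JSON has no
# encoding for infinities, so responses containing them would fail to
# serialize for clients.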
def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    if prompt_logprobs is None:
        return prompt_logprobs

    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == float("-inf"):
                logprob_values.logprob = -9999.0
    return prompt_logprobs