Push logprob generation to LLMEngine (#3065)
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
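
This moves logprob post-processing out of the OpenAI-compatible server: the
engine now returns Logprob objects that already carry both the log probability
and the decoded token text, so _create_logprobs unpacks .logprob and
.decoded_token instead of receiving raw floats and re-tokenizing ids with
tokenizer.convert_ids_to_tokens. A create_streaming_error_response helper is
also added so streaming endpoints can emit the standard error payload as a
JSON string.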
@@ -1,4 +1,5 @@
 import asyncio
+import json
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
@@ -11,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               ModelCard, ModelList,
                                               ModelPermission)
 from vllm.lora.request import LoRARequest
+from vllm.sequence import Logprob
 
 logger = init_logger(__name__)
 
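For orientation, a minimal sketch of the Logprob type imported above, assuming
only the two attributes the code below actually reads; the real definition
lives in vllm.sequence and may carry more fields:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Logprob:
    # Sketch only: the log probability of a token together with its
    # already-decoded text, so the serving layer need not re-tokenize.
    logprob: float
    decoded_token: Optional[str] = None
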
@@ -83,7 +85,7 @@ class OpenAIServing:
     def _create_logprobs(
         self,
         token_ids: List[int],
-        top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
+        top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
         num_output_top_logprobs: Optional[int] = None,
         initial_text_offset: int = 0,
     ) -> LogProbs:
@@ -95,10 +97,10 @@ class OpenAIServing:
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id]
+                token_logprob = step_top_logprobs[token_id].logprob
             else:
                 token_logprob = None
-            token = self.tokenizer.convert_ids_to_tokens(token_id)
+            token = step_top_logprobs[token_id].decoded_token
             logprobs.tokens.append(token)
             logprobs.token_logprobs.append(token_logprob)
             if len(logprobs.text_offset) == 0:
@@ -110,7 +112,7 @@ class OpenAIServing:
 
             if num_output_top_logprobs:
                 logprobs.top_logprobs.append({
-                    self.tokenizer.convert_ids_to_tokens(i): p
+                    p.decoded_token: p.logprob
                     for i, p in step_top_logprobs.items()
                 } if step_top_logprobs else None)
         return logprobs
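A rough illustration of the resulting data flow, with hypothetical token ids,
values, and a serving instance named serving (none of these appear in the
diff):

# Each sampling step now maps candidate token ids to Logprob objects
# rather than bare floats keyed by token id.
top_logprobs = [
    {15339: Logprob(logprob=-0.1, decoded_token="Hello"),
     1917: Logprob(logprob=-2.3, decoded_token="World")},
    {1917: Logprob(logprob=-0.4, decoded_token="World")},
]
out = serving._create_logprobs(token_ids=[15339, 1917],
                               top_logprobs=top_logprobs,
                               num_output_top_logprobs=2)
# out.tokens          -> ["Hello", "World"]              (from decoded_token)
# out.token_logprobs  -> [-0.1, -0.4]                    (from .logprob)
# out.top_logprobs[0] -> {"Hello": -0.1, "World": -2.3}
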
@@ -124,6 +126,19 @@ class OpenAIServing:
             type=err_type,
             code=status_code.value)
 
+    def create_streaming_error_response(
+            self,
+            message: str,
+            err_type: str = "BadRequestError",
+            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
+        json_str = json.dumps({
+            "error":
+            self.create_error_response(message=message,
+                                       err_type=err_type,
+                                       status_code=status_code).model_dump()
+        })
+        return json_str
+
     async def _check_model(self, request) -> Optional[ErrorResponse]:
         if request.model == self.served_model:
             return
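A usage sketch for the new helper; the server-sent-events framing below is an
assumption about how a streaming handler would consume it, not part of this
diff:

# Hypothetical streaming completion handler:
async def stream():
    try:
        ...
    except ValueError as e:
        # Same payload shape as create_error_response, serialized to JSON.
        data = serving.create_streaming_error_response(str(e))
        yield f"data: {data}\n\n"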