Push logprob generation to LLMEngine (#3065)

Co-authored-by: Avnish Narayan <avnish@anyscale.com>
This commit is contained in:
Antoni Baum
2024-03-04 11:54:06 -08:00
committed by GitHub
parent 76e8a70476
commit 22de45235c
13 changed files with 551 additions and 331 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import json
from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Union
@@ -11,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
ModelCard, ModelList,
ModelPermission)
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
logger = init_logger(__name__)
@@ -83,7 +85,7 @@ class OpenAIServing:
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
@@ -95,10 +97,10 @@ class OpenAIServing:
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
token_logprob = step_top_logprobs[token_id].logprob
else:
token_logprob = None
token = self.tokenizer.convert_ids_to_tokens(token_id)
token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
@@ -110,7 +112,7 @@ class OpenAIServing:
if num_output_top_logprobs:
logprobs.top_logprobs.append({
self.tokenizer.convert_ids_to_tokens(i): p
p.decoded_token: p.logprob
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
return logprobs
@@ -124,6 +126,19 @@ class OpenAIServing:
type=err_type,
code=status_code.value)
def create_streaming_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
json_str = json.dumps({
"error":
self.create_error_response(message=message,
err_type=err_type,
status_code=status_code).model_dump()
})
return json_str
async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model == self.served_model:
return