[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)

This commit is contained in:
Nick Hill
2024-10-18 08:19:53 +01:00
committed by GitHub
parent 944dd8edaf
commit 1ffc8a7362
4 changed files with 26 additions and 14 deletions

View File

@@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, List, Optional
from vllm.sequence import Logprob
@dataclass
@@ -11,6 +13,7 @@ class BeamSearchSequence:
"""
# The tokens includes the prompt.
tokens: List[int]
logprobs: List[Dict[int, Logprob]]
cum_logprob: float = 0.0
text: Optional[str] = None
@@ -28,7 +31,7 @@ class BeamSearchInstance:
def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
]
self.completed: List[BeamSearchSequence] = []