[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from vllm.sequence import Logprob
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -11,6 +13,7 @@ class BeamSearchSequence:
|
||||
"""
|
||||
# The tokens includes the prompt.
|
||||
tokens: List[int]
|
||||
logprobs: List[Dict[int, Logprob]]
|
||||
cum_logprob: float = 0.0
|
||||
text: Optional[str] = None
|
||||
|
||||
@@ -28,7 +31,7 @@ class BeamSearchInstance:
|
||||
|
||||
def __init__(self, prompt_tokens: List[int]):
|
||||
self.beams: List[BeamSearchSequence] = [
|
||||
BeamSearchSequence(tokens=prompt_tokens)
|
||||
BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
|
||||
]
|
||||
self.completed: List[BeamSearchSequence] = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user