[Frontend] API support for beam search for MQLLMEngine (#9117)
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import itertools
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
|
||||
Union, cast, overload)
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
||||
BeamSearchSequence, get_beam_search_score)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
@@ -28,43 +29,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
get_cached_tokenizer)
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
|
||||
is_list_of)
|
||||
from vllm.utils import Counter, deprecate_kwargs, is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Records the token ids accumulated so far together with the running
    cumulative log-probability used to rank beams. The decoded ``text``
    is left unset while searching and only populated for sequences that
    are about to be handed back to the user.
    """
    # Full token-id list, prompt tokens included.
    tokens: List[int]
    # Sum of per-token log-probabilities; 0.0 for a freshly-seeded beam.
    cum_logprob: float = 0.0
    # Decoded string form; filled in lazily just before returning results.
    text: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class BeamSearchOutput:
    """Final result of a beam search run.

    Wraps the list of best-scoring sequences; the list length matches
    the beam width the caller requested.
    """
    # Best sequences found, one entry per beam.
    sequences: List[BeamSearchSequence]
|
||||
|
||||
|
||||
class BeamSearchInstance:
    """Mutable search state for a single beam-search request."""

    def __init__(self, prompt_tokens: List[int]):
        # Search starts from one beam that contains only the prompt tokens.
        seed_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [seed_beam]
        # Beams that have terminated (e.g. emitted EOS) accumulate here.
        self.completed: List[BeamSearchSequence] = []
|
||||
|
||||
|
||||
class LLM:
|
||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user