[Frontend] API support for beam search for MQLLMEngine (#9117)

This commit is contained in:
Brendan Wong
2024-10-07 22:51:43 -07:00
committed by GitHub
parent e1faa2a598
commit 8c746226c9
8 changed files with 215 additions and 106 deletions

View File

@@ -1,12 +1,13 @@
import itertools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
Union, cast, overload)
from tqdm import tqdm
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
@@ -28,43 +29,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
is_list_of)
from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
@dataclass
class BeamSearchSequence:
    """A sequence for beam search.

    It keeps track of the tokens and the log probability of the sequence.
    The text field is optional and will only be filled when the sequence is
    about to be returned to the user.
    """
    # The tokens includes the prompt.
    tokens: List[int]
    # Cumulative log probability of the generated tokens in this sequence.
    cum_logprob: float = 0.0
    # Detokenized text; left as None until the sequence is about to be
    # returned to the user.
    text: Optional[str] = None
@dataclass
class BeamSearchOutput:
    """The output of beam search.

    It contains the list of the best beam search sequences.
    The length of the list is equal to the beam width.
    """
    # Best beam search sequences; len(sequences) == beam width.
    sequences: List[BeamSearchSequence]
class BeamSearchInstance:
    """Mutable per-prompt state for a beam search run.

    Holds the currently-live beams and the sequences that have finished.
    """

    def __init__(self, prompt_tokens: List[int]):
        # Beam search starts from a single beam holding only the prompt.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
        # Finished sequences are moved here as beams complete.
        self.completed: List[BeamSearchSequence] = []
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.