[Frontend][Core] Move guided decoding params into sampling params (#8252)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
@@ -1,4 +1,5 @@
 import itertools
+import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
@@ -16,13 +17,13 @@ from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.model_executor.guided_decoding import (
-    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
-from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
+from vllm.model_executor.guided_decoding.guided_fields import (
+    GuidedDecodingRequest, LLMGuidedOptions)
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
+                                  SamplingParams)
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -798,6 +799,14 @@ class LLM:
         guided_options: Optional[GuidedDecodingRequest] = None,
         priority: Optional[List[int]] = None,
     ) -> None:
+        if guided_options is not None:
+            warnings.warn(
+                "guided_options_request is deprecated, use "
+                "SamplingParams.guided_decoding instead",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         if isinstance(prompts, (str, dict)):
             # Convert a single prompt to a list.
             prompts = [prompts]
@@ -813,7 +822,7 @@
 
         for sp in params if isinstance(params, list) else (params, ):
             if isinstance(sp, SamplingParams):
-                self._add_guided_processor(sp, guided_options)
+                self._add_guided_params(sp, guided_options)
 
                 # We only care about the final output
                 sp.output_kind = RequestOutputKind.FINAL_ONLY
@@ -847,22 +856,25 @@
             priority=priority,
         )
 
-    def _add_guided_processor(
+    def _add_guided_params(
             self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
-        if guided_options:
-            if guided_options.guided_decoding_backend is None:
-                decoding_config = self.llm_engine.get_decoding_config()
-                guided_options.guided_decoding_backend = (
-                    decoding_config.guided_decoding_backend)
-            guided_logits_processor = get_local_guided_decoding_logits_processor(  #noqa
-                guided_options.guided_decoding_backend, guided_options,
-                self.get_tokenizer())
-            if guided_logits_processor:
-                if params.logits_processors is None:
-                    params.logits_processors = []
-                params.logits_processors.append(guided_logits_processor)
+        if guided_options is None:
+            return params
+
+        if params.guided_decoding is not None:
+            raise ValueError("Cannot set both guided_options_request and"
+                             "params.guided_decoding.")
+
+        params.guided_decoding = GuidedDecodingParams(
+            json=guided_options.guided_json,
+            regex=guided_options.guided_regex,
+            choice=guided_options.guided_choice,
+            grammar=guided_options.guided_grammar,
+            json_object=guided_options.guided_json_object,
+            backend=guided_options.guided_decoding_backend,
+            whitespace_pattern=guided_options.guided_whitespace_pattern)
         return params
 
     def _run_engine(
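For context, a minimal sketch of the new-style call after this change. The model name, prompt, and sampling values below are illustrative only and not part of the diff; the guided_decoding field and GuidedDecodingParams are the additions shown above.

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Illustrative model and prompt; any vLLM-supported model would do.
llm = LLM(model="facebook/opt-125m")

# Guided decoding is now configured on SamplingParams.guided_decoding
# rather than through the deprecated guided_options_request argument.
params = SamplingParams(
    max_tokens=16,
    guided_decoding=GuidedDecodingParams(choice=["positive", "negative"]),
)

outputs = llm.generate(["The sentiment of 'I love this' is"], params)
print(outputs[0].outputs[0].text)

Callers that still pass guided_options_request keep working for now: they receive a DeprecationWarning and the options are converted into GuidedDecodingParams by _add_guided_params, as shown in the last hunk above.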