[Renderer] Move InputPreprocessor into Renderer (1.5/2) (#34598)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -93,14 +93,14 @@ def _build_renderer(


 def _preprocess_prompt(
-    mdoel_config: ModelConfig,
+    model_config: ModelConfig,
     prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
 ):
     return [
         (
             prompt
             if isinstance(prompt, bytes)
-            else parse_model_prompt(mdoel_config, prompt)
+            else parse_model_prompt(model_config, prompt)
         )
         for prompt in prompt_to_seq(prompt_or_prompts)
     ]
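For context, a standalone sketch of the dispatch this helper implements, with `prompt_to_seq` and `parse_model_prompt` replaced by hypothetical stand-ins (the real helpers are not shown in this diff): raw `bytes` prompts pass through untouched, while everything else is parsed against the model config.

    # Hypothetical stand-ins for illustration only; the real helpers live
    # elsewhere in the renderer module.
    def prompt_to_seq(prompt_or_prompts):
        if isinstance(prompt_or_prompts, (str, bytes, dict)):
            return [prompt_or_prompts]
        return list(prompt_or_prompts)

    def parse_model_prompt(model_config, prompt):
        return {"prompt": prompt}

    def _preprocess_prompt(model_config, prompt_or_prompts):
        return [
            (p if isinstance(p, bytes) else parse_model_prompt(model_config, p))
            for p in prompt_to_seq(prompt_or_prompts)
        ]

    # bytes pass through; strings are parsed
    assert _preprocess_prompt(None, [b"\x00", "Hi"]) == [b"\x00", {"prompt": "Hi"}]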
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, Literal, TypeAlias

 import torch
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, assert_never

 if TYPE_CHECKING:
     from vllm.multimodal.inputs import (
@@ -200,15 +200,22 @@ class TokenInputs(_InputOptions):
     prompt_token_ids: list[int]
     """The token IDs of the prompt."""

+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+

 def token_inputs(
     prompt_token_ids: list[int],
     *,
+    prompt: str | None = None,
     cache_salt: str | None = None,
 ) -> TokenInputs:
     """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
     values."""
     inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

+    if prompt is not None:
+        inputs["prompt"] = prompt
     if cache_salt is not None:
         inputs["cache_salt"] = cache_salt
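A minimal usage sketch of the extended factory (the import path follows the `vllm.inputs.data` reference in the docstring, and it assumes the existing `return inputs` tail of the function, which falls outside this hunk): optional fields are only inserted when provided, so the `NotRequired` keys stay absent rather than becoming `None`. The same pattern applies to `embeds_inputs()` below.

    from vllm.inputs.data import token_inputs

    full = token_inputs([1, 2, 3], prompt="Hello", cache_salt="salt-0")
    assert full["type"] == "token"
    assert full["prompt"] == "Hello"

    bare = token_inputs([1, 2, 3])
    assert "prompt" not in bare and "cache_salt" not in bare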
@@ -224,15 +231,22 @@ class EmbedsInputs(_InputOptions):
     prompt_embeds: torch.Tensor
     """The embeddings of the prompt."""

+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+

 def embeds_inputs(
     prompt_embeds: torch.Tensor,
     *,
+    prompt: str | None = None,
     cache_salt: str | None = None,
 ) -> EmbedsInputs:
     """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
     values."""
     inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

+    if prompt is not None:
+        inputs["prompt"] = prompt
     if cache_salt is not None:
         inputs["cache_salt"] = cache_salt
@@ -278,10 +292,12 @@ class EncoderDecoderInputs(TypedDict):
     for encoder-decoder models.
     """

-    encoder: EncoderInputs
+    type: Literal["enc_dec"]
+
+    encoder_prompt: EncoderInputs
     """The inputs for the encoder portion."""

-    decoder: DecoderInputs
+    decoder_prompt: DecoderInputs
     """The inputs for the decoder portion."""
@@ -296,3 +312,94 @@ which can be passed to
 SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
 """The inputs for a single encoder/decoder prompt."""

+
+def _validate_enc_inputs(inputs: SingletonInputs) -> EncoderInputs:
+    if inputs["type"] == "embeds":
+        raise ValueError(
+            "Embedding inputs are not supported for encoder-decoder models"
+        )
+
+    if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
+        raise RuntimeError(
+            "You should register an encoder-decoder multi-modal processor "
+            "for encoder-decoder models."
+        )
+
+    return inputs  # type: ignore[return-value]
+
+
+def _validate_dec_inputs(inputs: SingletonInputs) -> DecoderInputs:
+    if inputs["type"] == "embeds":
+        raise ValueError(
+            "Embedding inputs are not supported for encoder-decoder models"
+        )
+
+    return inputs
+
+
+def _prepare_decoder_input_ids_for_generation(
+    decoder_input_ids: list[int],
+    decoder_start_token_id: int,
+) -> list[int]:
+    """
+    Prepare `decoder_input_ids` for generation with encoder-decoder models,
+    according to `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
+
+    Source:
+    https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py
+    """
+    if len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id:
+        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
+
+    return decoder_input_ids
+
+
+def build_enc_dec_inputs(
+    encoder_inputs: SingletonInputs,
+    decoder_inputs: SingletonInputs | None,
+    decoder_start_token_id: int,
+) -> EncoderDecoderInputs:
+    enc_inputs = _validate_enc_inputs(encoder_inputs)
+
+    if decoder_inputs is None:
+        dec_inputs: DecoderInputs = enc_inputs
+    else:
+        dec_inputs = _validate_dec_inputs(decoder_inputs)
+
+    enc_inputs_new: EncoderInputs
+    dec_inputs_new: DecoderInputs
+
+    if enc_inputs["type"] == "multimodal":
+        from vllm.multimodal.inputs import mm_inputs
+
+        enc_inputs_new = token_inputs(
+            enc_inputs["encoder_prompt_token_ids"],
+            prompt=enc_inputs.get("encoder_prompt"),
+        )
+        dec_inputs_new = mm_inputs(
+            prompt_token_ids=dec_inputs["prompt_token_ids"],
+            prompt=dec_inputs.get("prompt"),
+            mm_kwargs=enc_inputs["mm_kwargs"],
+            mm_hashes=enc_inputs["mm_hashes"],
+            mm_placeholders=enc_inputs["mm_placeholders"],
+        )
+    elif enc_inputs["type"] == "token":
+        enc_inputs_new = token_inputs(prompt_token_ids=[])
+        dec_inputs_new = dec_inputs
+    else:
+        assert_never(enc_inputs)
+
+    dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
+        dec_inputs_new["prompt_token_ids"],
+        decoder_start_token_id,
+    )
+
+    if cache_salt := enc_inputs.get("cache_salt"):
+        dec_inputs_new["cache_salt"] = cache_salt
+
+    return EncoderDecoderInputs(
+        type="enc_dec",
+        encoder_prompt=enc_inputs_new,
+        decoder_prompt=dec_inputs_new,
+    )
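A hedged sketch of the new module-level entry point for the plain-token case, assuming `build_enc_dec_inputs` and `token_inputs` are both importable from `vllm.inputs.data` (the import added to the preprocessor below confirms the former): with no explicit decoder inputs, the encoder tokens are reused for the decoder and the start token is prepended, while the encoder half is emptied.

    from vllm.inputs.data import build_enc_dec_inputs, token_inputs

    enc = token_inputs([10, 11, 12], prompt="translate: hello")
    out = build_enc_dec_inputs(enc, None, decoder_start_token_id=0)

    assert out["type"] == "enc_dec"
    # Plain token inputs: the encoder half is emptied, and the original
    # tokens become the decoder prompt with the start token prepended.
    assert out["encoder_prompt"]["prompt_token_ids"] == []
    assert out["decoder_prompt"]["prompt_token_ids"] == [0, 10, 11, 12]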
@@ -7,11 +7,7 @@ from .data import ProcessorInputs, SingletonInputs
 def split_enc_dec_inputs(
     inputs: ProcessorInputs,
 ) -> tuple[SingletonInputs | None, SingletonInputs]:
-    if "encoder" in inputs and "decoder" in inputs:
-        # NOTE: This passes pyright but not mypy
-        return (
-            inputs["encoder"],  # type: ignore[typeddict-item]
-            inputs["decoder"],  # type: ignore[typeddict-item]
-        )
+    if inputs["type"] == "enc_dec":
+        return inputs["encoder_prompt"], inputs["decoder_prompt"]

     return None, inputs
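Continuing the sketch above, the split is now driven by the `type` tag instead of probing for keys, so the mypy/pyright workaround disappears; decoder-only inputs simply come back with no encoder half. The import path is an assumption based on this hunk's `from .data import ...` header (presumably `vllm/inputs/parse.py`).

    from vllm.inputs.parse import split_enc_dec_inputs

    enc_half, dec_half = split_enc_dec_inputs(out)
    assert enc_half is out["encoder_prompt"]
    assert dec_half is out["decoder_prompt"]

    # Decoder-only inputs have no encoder half.
    assert split_enc_dec_inputs(token_inputs([1, 2]))[0] is None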
@@ -7,6 +7,7 @@ from typing import Any, overload
 from typing_extensions import assert_never

 from vllm.config import VllmConfig
+from vllm.inputs.data import build_enc_dec_inputs
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.inputs import (
@@ -67,54 +68,6 @@ class InputPreprocessor:
     def get_tokenizer(self) -> TokenizerLike:
         return self.renderer.get_tokenizer()

-    def get_decoder_start_token_id(self) -> int:
-        """
-        Obtain the decoder start token id employed by an encoder/decoder
-        model. Raises an error if it is not available.
-        """
-        dec_start_token_id = getattr(
-            self.model_config.hf_config, "decoder_start_token_id", None
-        )
-
-        if dec_start_token_id is None:
-            logger.warning_once(
-                "Falling back on <BOS> for decoder start token id "
-                "because decoder start token id is not available."
-            )
-            dec_start_token_id = self.renderer.get_bos_token_id()
-
-        if dec_start_token_id is None:
-            raise RuntimeError("Cannot find decoder start token id or <BOS>")
-
-        return dec_start_token_id
-
-    def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
-        """
-        Prepares `decoder_input_ids` for generation with encoder-decoder models.
-
-        Based on:
-        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
-        specifically,
-        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
-
-        Arguments:
-
-        * decoder_input_ids: input token ids to preprocess
-
-        Returns:
-
-        * Processed token list
-        """
-        decoder_start_token_id = self.get_decoder_start_token_id()
-
-        if (
-            len(decoder_input_ids) == 0
-            or decoder_input_ids[0] != decoder_start_token_id
-        ):
-            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
-
-        return decoder_input_ids
-
     def _tokenize_prompt(
         self,
         prompt: str,
@@ -332,66 +285,6 @@ class InputPreprocessor:

         assert_never(prompt)  # type: ignore[arg-type]

-    def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
-        if inputs["type"] == "embeds":
-            raise ValueError(
-                "Embedding inputs are not supported for encoder-decoder models"
-            )
-
-        if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
-            raise RuntimeError(
-                "You should register an encoder-decoder "
-                "multi-modal processor for encoder-decoder models."
-            )
-
-        return inputs  # type: ignore[return-value]
-
-    def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
-        if inputs["type"] == "embeds":
-            raise ValueError(
-                "Embedding inputs are not supported for encoder-decoder models"
-            )
-
-        return inputs
-
-    def _build_enc_dec_inputs(
-        self,
-        encoder_inputs: SingletonInputs,
-        decoder_inputs: SingletonInputs | None = None,
-    ) -> EncoderDecoderInputs:
-        enc_inputs = self._validate_enc_inputs(encoder_inputs)
-
-        if decoder_inputs is None:
-            dec_inputs: DecoderInputs = enc_inputs  # type: ignore[assignment]
-        else:
-            dec_inputs = self._validate_dec_inputs(decoder_inputs)
-
-        enc_inputs_new: EncoderInputs
-        dec_inputs_new: DecoderInputs
-
-        if enc_inputs["type"] == "multimodal":
-            enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
-            dec_inputs_new = MultiModalInputs(
-                type="multimodal",
-                prompt_token_ids=dec_inputs["prompt_token_ids"],
-                mm_kwargs=enc_inputs["mm_kwargs"],
-                mm_hashes=enc_inputs["mm_hashes"],
-                mm_placeholders=enc_inputs["mm_placeholders"],
-            )
-        elif enc_inputs["type"] == "token":
-            enc_inputs_new = token_inputs(prompt_token_ids=[])
-            dec_inputs_new = dec_inputs
-        else:
-            assert_never(enc_inputs)
-
-        dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids(
-            dec_inputs_new["prompt_token_ids"]
-        )
-        if cache_salt := enc_inputs.get("cache_salt"):
-            dec_inputs_new["cache_salt"] = cache_salt
-
-        return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new)
-
     def _process_encoder_decoder_prompt(
         self,
         prompt: EncoderDecoderDictPrompt,
@@ -417,7 +310,7 @@ class InputPreprocessor:
         encoder_prompt = prompt["encoder_prompt"]
         decoder_prompt = prompt["decoder_prompt"]

-        return self._build_enc_dec_inputs(
+        return build_enc_dec_inputs(
             encoder_inputs=self._prompt_to_llm_inputs(
                 encoder_prompt,
                 tokenization_kwargs=tokenization_kwargs,
@@ -431,6 +324,7 @@ class InputPreprocessor:
                     tokenization_kwargs=tokenization_kwargs,
                 )
             ),
+            decoder_start_token_id=self.renderer.get_dec_start_token_id(),
         )

     def _process_decoder_only_prompt(
@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
     MultiModalInputs,
     MultiModalKwargsItems,
     MultiModalUUIDDict,
+    mm_inputs,
 )
 from vllm.multimodal.parse import (
     ImageEmbeddingItems,
@@ -837,8 +838,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             for modality, placeholders in mm_placeholders.items()
         }

-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,
@@ -48,6 +48,7 @@ from vllm.multimodal.inputs import (
     MultiModalKwargsItems,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_inputs,
 )
 from vllm.multimodal.parse import (
     DictEmbeddingItems,
@@ -222,8 +223,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
             ),
         )

-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=[1],
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,
@@ -33,6 +33,7 @@ from vllm.multimodal.inputs import (
     MultiModalInputs,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_inputs,
 )
 from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import (
@@ -260,8 +261,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
         )

-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,
@@ -20,7 +20,7 @@ from typing import (

 import numpy as np
 from PIL.Image import Image
-from typing_extensions import TypeVar
+from typing_extensions import NotRequired, TypeVar

 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.import_utils import LazyLoader
@@ -1075,6 +1075,9 @@ class MultiModalInputs(_InputOptions):
     prompt_token_ids: list[int]
     """The processed token IDs which includes placeholder tokens."""

+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+
     mm_kwargs: MultiModalKwargsOptionalItems
     """Keyword arguments to be directly passed to the model after batching."""
@@ -1088,6 +1091,31 @@ class MultiModalInputs(_InputOptions):
     """


+def mm_inputs(
+    prompt_token_ids: list[int],
+    mm_kwargs: MultiModalKwargsOptionalItems,
+    mm_hashes: MultiModalHashes,
+    mm_placeholders: MultiModalPlaceholderDict,
+    *,
+    prompt: str | None = None,
+    cache_salt: str | None = None,
+) -> MultiModalInputs:
+    inputs = MultiModalInputs(
+        type="multimodal",
+        prompt_token_ids=prompt_token_ids,
+        mm_kwargs=mm_kwargs,
+        mm_hashes=mm_hashes,
+        mm_placeholders=mm_placeholders,
+    )
+
+    if prompt is not None:
+        inputs["prompt"] = prompt
+    if cache_salt is not None:
+        inputs["cache_salt"] = cache_salt
+
+    return inputs
+
+
 class MultiModalEncDecInputs(MultiModalInputs):
     """
     Represents the outputs of
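A usage sketch of the new factory, with empty stand-in dicts for the multi-modal fields (real callers pass the processor's `mm_kwargs`/`mm_hashes`/`mm_placeholders` through unchanged, as the processor hunks above show):

    from vllm.multimodal.inputs import mm_inputs

    inputs = mm_inputs(
        prompt_token_ids=[1, 2, 3],
        mm_kwargs={},        # stand-in; normally MultiModalKwargsOptionalItems
        mm_hashes={},        # stand-in; normally MultiModalHashes
        mm_placeholders={},  # stand-in; normally MultiModalPlaceholderDict
        cache_salt="salt-0",
    )
    assert inputs["type"] == "multimodal"
    assert "prompt" not in inputs  # NotRequired keys stay absent unless given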
@@ -1101,3 +1129,31 @@ class MultiModalEncDecInputs(MultiModalInputs):

     encoder_prompt_token_ids: list[int]
     """The processed token IDs of the encoder prompt."""
+
+    encoder_prompt: NotRequired[str]
+    """The prompt text corresponding to the encoder token IDs, if available."""
+
+
+def mm_enc_dec_inputs(
+    encoder_inputs: MultiModalInputs,
+    decoder_prompt_token_ids: list[int],
+    *,
+    decoder_prompt: str | None = None,
+) -> MultiModalEncDecInputs:
+    inputs = MultiModalEncDecInputs(
+        type="multimodal",
+        prompt_token_ids=decoder_prompt_token_ids,
+        encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
+        mm_kwargs=encoder_inputs["mm_kwargs"],
+        mm_hashes=encoder_inputs["mm_hashes"],
+        mm_placeholders=encoder_inputs["mm_placeholders"],
+    )
+
+    if decoder_prompt is not None:
+        inputs["prompt"] = decoder_prompt
+    if "prompt" in encoder_inputs:
+        inputs["encoder_prompt"] = encoder_inputs["prompt"]
+    if "cache_salt" in encoder_inputs:
+        inputs["cache_salt"] = encoder_inputs["cache_salt"]
+
+    return inputs
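A sketch of the encoder-decoder factory, reusing the `mm_inputs` stand-ins from the previous example; note how it lifts the encoder's `prompt_token_ids` and `prompt` into the `encoder_*` keys, which is what replaces the old `**encoder_inputs` splat in `EncDecMultiModalProcessor` below.

    from vllm.multimodal.inputs import mm_enc_dec_inputs, mm_inputs

    encoder = mm_inputs(
        prompt_token_ids=[5, 6],
        mm_kwargs={},
        mm_hashes={},
        mm_placeholders={},
        prompt="<image>",
    )
    out = mm_enc_dec_inputs(encoder, decoder_prompt_token_ids=[0])

    assert out["encoder_prompt_token_ids"] == [5, 6]
    assert out["encoder_prompt"] == "<image>"  # lifted from encoder "prompt"
    assert out["prompt_token_ids"] == [0]      # the decoder prompt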
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Generic, TypeVar
@@ -26,7 +26,7 @@ class MediaWithBytes(Generic[_T]):
     """

     media: _T
-    original_bytes: bytes
+    original_bytes: bytes = field(repr=False)

     def __array__(self, *args, **kwargs) -> np.ndarray:
         """Allow np.array(obj) to return np.array(obj.media)."""
@@ -34,6 +34,8 @@ from ..inputs import (
     MultiModalKwargsOptionalItems,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_enc_dec_inputs,
+    mm_inputs,
 )
 from ..parse import (
     DictEmbeddingItems,
@@ -1803,8 +1805,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             for modality, placeholders in mm_placeholders.items()
         }

-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_info.kwargs,
             mm_hashes=mm_info.hashes,
@@ -1848,12 +1849,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         else:
             decoder_prompt_ids = decoder_prompt_raw

-        mm_inputs = MultiModalEncDecInputs(
-            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
-            **encoder_inputs,
+        return mm_enc_dec_inputs(
+            encoder_inputs,
+            decoder_prompt_ids,
         )
-        mm_inputs["prompt_token_ids"] = decoder_prompt_ids
-        return mm_inputs

     def apply(
         self,
@@ -153,6 +153,27 @@ class BaseRenderer(ABC, Generic[_T]):

         return self.tokenizer.eos_token_id

+    def get_dec_start_token_id(self) -> int:
+        """
+        Obtain the decoder start token id employed by an encoder/decoder model,
+        raising an error if it is not available.
+        """
+        dec_start_token_id = getattr(
+            self.model_config.hf_config, "decoder_start_token_id", None
+        )
+
+        if dec_start_token_id is None:
+            logger.warning_once(
+                "Falling back on <BOS> for decoder start token id "
+                "because decoder start token id is not available."
+            )
+            dec_start_token_id = self.get_bos_token_id()
+
+        if dec_start_token_id is None:
+            raise RuntimeError("Cannot find decoder start token id or <BOS>")
+
+        return dec_start_token_id
+
     @cached_property
     def default_cmpl_tok_params(self) -> TokenizeParams:
         mm_processor = self.mm_processor
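A standalone sketch of the fallback chain `get_dec_start_token_id` implements, using a hypothetical stub config (the real method reads `self.model_config.hf_config` and `self.get_bos_token_id()`): the config value first, then <BOS>, then a hard error.

    def _resolve_dec_start(hf_config: object, bos_token_id: int | None) -> int:
        dec_start = getattr(hf_config, "decoder_start_token_id", None)
        if dec_start is None:
            dec_start = bos_token_id  # fall back on <BOS>
        if dec_start is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")
        return dec_start

    class _StubConfig:  # hypothetical; no decoder_start_token_id attribute
        pass

    assert _resolve_dec_start(_StubConfig(), bos_token_id=1) == 1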