[Doc] Add docs for prompt replacement (#12318)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -218,7 +218,7 @@ class UltravoxMultiModalProcessor(
|
|||||||
return [
|
return [
|
||||||
PromptReplacement(
|
PromptReplacement(
|
||||||
modality="audio",
|
modality="audio",
|
||||||
target='<|audio|>',
|
target="<|audio|>",
|
||||||
replacement=get_replacement_ultravox,
|
replacement=get_replacement_ultravox,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -29,41 +29,101 @@ if TYPE_CHECKING:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_S = TypeVar("_S", str, list[int])
|
_S = TypeVar("_S", str, list[int])
|
||||||
_PromptSeq = Union[str, list[int]]
|
|
||||||
|
PromptSeq = Union[str, list[int]]
|
||||||
|
"""A token sequence (list of token IDs) or text."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PromptReplacementDetails:
|
class PromptReplacementDetails:
|
||||||
full: _PromptSeq
|
"""Details about the replacement token sequence or text."""
|
||||||
|
|
||||||
|
full: PromptSeq
|
||||||
"""The full replacement."""
|
"""The full replacement."""
|
||||||
|
|
||||||
features: _PromptSeq
|
features: PromptSeq
|
||||||
"""
|
"""
|
||||||
The part of the replacement that corresponds to placeholder feature tokens.
|
The part of the replacement that corresponds to feature placeholders;
|
||||||
|
this will be replaced by the output of the vision encoder during model
|
||||||
|
inference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_seq(seq: _PromptSeq) -> "PromptReplacementDetails":
|
def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
|
||||||
return PromptReplacementDetails(full=seq, features=seq)
|
return PromptReplacementDetails(full=seq, features=seq)
|
||||||
|
|
||||||
|
|
||||||
_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]
|
PromptRepl = Union[PromptSeq, PromptReplacementDetails]
|
||||||
|
"""
|
||||||
|
The replacement token sequence or text.
|
||||||
|
|
||||||
|
If only part of the replacement corresponds to feature placeholders, you can
|
||||||
|
use :class:`PromptReplacementDetails` to specify which part.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PromptReplacement:
|
class PromptReplacement:
|
||||||
"""
|
"""
|
||||||
Defines how to replace portions of an input prompt with placeholder tokens.
|
Defines how to replace portions of an input prompt with placeholder tokens.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
For each image, replace one ``<image>`` input placeholder in the prompt
|
||||||
|
with a number of ``<image>`` feature placeholders
|
||||||
|
equal to the feature size of the vision encoder:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
PromptReplacement(
|
||||||
|
modality="image",
|
||||||
|
target="<image>",
|
||||||
|
replacement="<image>" * image_feature_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
As above, but further pad the feature placeholders with ``<image_bos>``
|
||||||
|
and `<image_eos>``, which are not supposed to be passed to the vision
|
||||||
|
encoder:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
PromptReplacement(
|
||||||
|
modality="image",
|
||||||
|
target="<image>",
|
||||||
|
replacement=PromptReplacementDetails(
|
||||||
|
full="".join([
|
||||||
|
"<image_bos>",
|
||||||
|
"<image>" * image_feature_size,
|
||||||
|
"<image_eos>",
|
||||||
|
]),
|
||||||
|
features="<image>" * image_feature_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
To avoid unnecessary tokenization during prompt replacement,
|
||||||
|
we recommended passing token sequences instead of text:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
PromptReplacement(
|
||||||
|
modality="image",
|
||||||
|
target=[image_token_id],
|
||||||
|
replacement=PromptReplacementDetails(
|
||||||
|
full=([image_bos_id] + [image_token_id] * image_feature_size
|
||||||
|
+ [image_eos_id]),
|
||||||
|
features=[image_token_id] * image_feature_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
modality: str
|
modality: str
|
||||||
"""The modality for which the replacement is made."""
|
"""The modality for which the replacement is made."""
|
||||||
|
|
||||||
target: _PromptSeq
|
target: PromptSeq
|
||||||
"""The token sequence (or text) to find and replace."""
|
"""The token sequence (or text) to find and replace."""
|
||||||
|
|
||||||
replacement: Union[Callable[[int], _PromptRepl],
|
replacement: Union[Callable[[int], PromptRepl],
|
||||||
_PromptRepl] = field(repr=False)
|
PromptRepl] = field(repr=False)
|
||||||
"""
|
"""
|
||||||
Given the index of the processed item within :attr:`modality`,
|
Given the index of the processed item within :attr:`modality`,
|
||||||
output the replacement token sequence (or text).
|
output the replacement token sequence (or text).
|
||||||
@@ -126,6 +186,10 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _BoundPromptSequence:
|
class _BoundPromptSequence:
|
||||||
|
"""
|
||||||
|
A :data:`_PromptSeq` bound to a tokenizer to automatically
|
||||||
|
convert between token sequence and text representations.
|
||||||
|
"""
|
||||||
tokenizer: AnyTokenizer = field(repr=False)
|
tokenizer: AnyTokenizer = field(repr=False)
|
||||||
|
|
||||||
_text: Optional[str]
|
_text: Optional[str]
|
||||||
@@ -134,7 +198,7 @@ class _BoundPromptSequence:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_seq(
|
def from_seq(
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
seq: _PromptSeq,
|
seq: PromptSeq,
|
||||||
) -> "_BoundPromptSequence":
|
) -> "_BoundPromptSequence":
|
||||||
return _BoundPromptSequence(
|
return _BoundPromptSequence(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@@ -180,9 +244,9 @@ class BoundPromptReplacement:
|
|||||||
tokenizer: AnyTokenizer = field(repr=False)
|
tokenizer: AnyTokenizer = field(repr=False)
|
||||||
modality: str
|
modality: str
|
||||||
|
|
||||||
_target: _PromptSeq
|
_target: PromptSeq
|
||||||
_replacement: Union[Callable[[int], _PromptRepl],
|
_replacement: Union[Callable[[int], PromptRepl],
|
||||||
_PromptRepl] = field(repr=False)
|
PromptRepl] = field(repr=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
|
self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
|
||||||
@@ -350,7 +414,7 @@ def find_text_matches(
|
|||||||
|
|
||||||
|
|
||||||
def _resolve_matches(
|
def _resolve_matches(
|
||||||
prompt: _PromptSeq,
|
prompt: PromptSeq,
|
||||||
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
|
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
|
||||||
) -> list[_PromptReplacementMatch]:
|
) -> list[_PromptReplacementMatch]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user