[Core] Add engine option to return only deltas or final output (#7381)
This commit is contained in:
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
|
||||
from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
|
||||
Optional, Set, Tuple, Union, cast)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Union, cast
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
@@ -407,6 +408,10 @@ class Sequence:
|
||||
self.status = SequenceStatus.WAITING
|
||||
self.stop_reason: Union[int, str, None] = None
|
||||
|
||||
# These are used to keep track of delta outputs
|
||||
self._last_token_ids_offset: int = 0
|
||||
self._last_output_text_offset: int = 0
|
||||
|
||||
# Used for incremental detokenization
|
||||
self.prefix_offset = 0
|
||||
self.read_offset = 0
|
||||
@@ -462,11 +467,35 @@ class Sequence:
|
||||
return self.prompt_adapter_request.prompt_adapter_id \
|
||||
if self.prompt_adapter_request else 0
|
||||
|
||||
def get_output_text_to_return(self, buffer_length: int):
|
||||
def get_output_text_to_return(self, buffer_length: int,
|
||||
delta: bool) -> str:
|
||||
"""If delta is True, only new text since the last call to
|
||||
this method is returned"""
|
||||
|
||||
# We return the full output text if the sequence is finished.
|
||||
truncate = buffer_length and not self.is_finished()
|
||||
return self.output_text[:-buffer_length] if truncate else (
|
||||
self.output_text)
|
||||
if not delta:
|
||||
return self.output_text[:-buffer_length] if truncate else (
|
||||
self.output_text)
|
||||
length = len(self.output_text) - buffer_length
|
||||
last_offset = self._last_output_text_offset
|
||||
if last_offset < length:
|
||||
self._last_output_text_offset = length
|
||||
return self.output_text[last_offset:length]
|
||||
return ""
|
||||
|
||||
def get_output_token_ids_to_return(self,
|
||||
delta: bool) -> GenericSequence[int]:
|
||||
"""If delta is True, only new tokens since the last call to
|
||||
this method are returned"""
|
||||
if not delta:
|
||||
return self.get_output_token_ids()
|
||||
length = self.get_output_len()
|
||||
last_offset = self._last_token_ids_offset
|
||||
if last_offset < length:
|
||||
self._last_token_ids_offset = length
|
||||
return self.data._output_token_ids[last_offset:]
|
||||
return ()
|
||||
|
||||
def hash_of_block(self, logical_idx: int) -> int:
|
||||
# TODO This can produce incorrect hash when block size > prompt size
|
||||
|
||||
Reference in New Issue
Block a user