[Core] Asynchronous Output Processor (#7049)

Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
This commit is contained in:
Megha Agarwal
2024-08-26 20:53:20 -07:00
committed by GitHub
parent 015e6cc252
commit 2eedede875
21 changed files with 652 additions and 214 deletions

View File

@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Optional, Set, Tuple, Union, cast)
import msgspec
import torch
@@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
def append_token_id(self, token_id: int, logprobs: Dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
@@ -1293,6 +1290,8 @@ class ExecuteModelRequest(
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
# Optional callback for asynchronous output processing.
output_proc_callback_fn: Optional[Callable] = None
@property
def is_first_multi_step(self) -> bool:
@@ -1338,4 +1337,5 @@ class ExecuteModelRequest(
num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None)
if self.last_sampled_token_ids is not None else None,
output_proc_callback_fn=self.output_proc_callback_fn)