[Core] Asynchronous Output Processor (#7049)
Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
This commit is contained in:
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
|
||||
from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
|
||||
Tuple, Union, cast)
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
|
||||
Optional, Set, Tuple, Union, cast)
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
@@ -474,11 +474,8 @@ class Sequence:
|
||||
"""Reset the sequence states for recomputation."""
|
||||
self.data.reset_state_for_recompute()
|
||||
|
||||
def append_token_id(
|
||||
self,
|
||||
token_id: int,
|
||||
logprobs: Dict[int, Logprob],
|
||||
) -> None:
|
||||
def append_token_id(self, token_id: int, logprobs: Dict[int,
|
||||
Logprob]) -> None:
|
||||
assert token_id in logprobs
|
||||
self.output_logprobs.append(logprobs)
|
||||
self.data.append_token_id(token_id, logprobs[token_id].logprob)
|
||||
@@ -1293,6 +1290,8 @@ class ExecuteModelRequest(
|
||||
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
|
||||
# The last sampled token ids for multi-step decoding.
|
||||
last_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
# Callback function for the asynchronous output processor (None if unset).
|
||||
output_proc_callback_fn: Optional[Callable] = None
|
||||
|
||||
@property
|
||||
def is_first_multi_step(self) -> bool:
|
||||
@@ -1338,4 +1337,5 @@ class ExecuteModelRequest(
|
||||
num_steps=self.num_steps,
|
||||
finished_requests_ids=self.finished_requests_ids,
|
||||
last_sampled_token_ids=self.last_sampled_token_ids.clone()
|
||||
if self.last_sampled_token_ids is not None else None)
|
||||
if self.last_sampled_token_ids is not None else None,
|
||||
output_proc_callback_fn=self.output_proc_callback_fn)
|
||||
|
||||
Reference in New Issue
Block a user