[Core] Use array to speedup padding (#6779)
This commit is contained in:
@@ -3,6 +3,7 @@ import copy
|
||||
import enum
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
|
||||
@@ -119,10 +120,10 @@ class SequenceData:
|
||||
prompt_token_ids: List[int],
|
||||
output_token_ids: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
self._prompt_token_ids: List[int] = list(prompt_token_ids)
|
||||
self._prompt_token_ids = array('l', prompt_token_ids)
|
||||
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
|
||||
self._output_token_ids: List[int] = (
|
||||
list(output_token_ids) if output_token_ids is not None else [])
|
||||
self._output_token_ids = array(
|
||||
'l', output_token_ids if output_token_ids is not None else [])
|
||||
|
||||
self.cumulative_logprob = 0.0
|
||||
# The number of tokens that are computed (that run against the model).
|
||||
@@ -132,8 +133,8 @@ class SequenceData:
|
||||
self._update_cached_all_tokens()
|
||||
|
||||
def _update_cached_all_tokens(self):
|
||||
self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
|
||||
self._output_token_ids)
|
||||
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
|
||||
self._output_token_ids)
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self) -> Tuple[int, ...]:
|
||||
@@ -141,19 +142,27 @@ class SequenceData:
|
||||
|
||||
@prompt_token_ids.setter
|
||||
def prompt_token_ids(self, new_prompt_token_ids) -> None:
|
||||
self._prompt_token_ids = list(new_prompt_token_ids)
|
||||
self._prompt_token_ids = array('l', new_prompt_token_ids)
|
||||
self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
|
||||
self._update_cached_all_tokens()
|
||||
|
||||
@property
|
||||
def prompt_token_ids_array(self) -> array:
|
||||
return self._prompt_token_ids
|
||||
|
||||
@property
|
||||
def output_token_ids(self) -> Tuple[int, ...]:
|
||||
return tuple(self._output_token_ids)
|
||||
|
||||
@output_token_ids.setter
|
||||
def output_token_ids(self, new_output_token_ids) -> None:
|
||||
self._output_token_ids = list(new_output_token_ids)
|
||||
self._output_token_ids = array('l', new_output_token_ids)
|
||||
self._update_cached_all_tokens()
|
||||
|
||||
@property
|
||||
def output_token_ids_array(self) -> array:
|
||||
return self._output_token_ids
|
||||
|
||||
def append_token_id(self, token_id: int, logprob: float) -> None:
|
||||
self._output_token_ids.append(token_id)
|
||||
self._cached_all_token_ids.append(token_id)
|
||||
|
||||
Reference in New Issue
Block a user