[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)
This commit is contained in:
@@ -3,8 +3,9 @@ import copy
|
||||
import enum
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -916,6 +917,21 @@ def get_all_seq_ids(
|
||||
return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
|
||||
|
||||
|
||||
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect every sequence id and group the ids by request id.

    Walks each entry of ``seq_group_metadata_list`` and iterates its
    ``seq_data`` to obtain the sequence ids.

    Returns:
        A 2-tuple ``(seq_ids, request_id_seq_ids_mapping)``:

        * ``seq_ids`` lists every sequence id in encounter order (groups in
          list order, ids in ``seq_data`` iteration order).
        * ``request_id_seq_ids_mapping`` maps each group's ``request_id`` to
          the set of sequence ids seen for it. It is a ``defaultdict(set)``;
          a group whose ``seq_data`` yields no ids adds no key.
    """
    ids_by_request: Dict[str, Set[int]] = defaultdict(set)
    all_ids: List[int] = []

    # Flatten the two nesting levels into (request_id, seq_id) pairs so a
    # single pass fills both results.
    pairs = ((group.request_id, seq_id)
             for group in seq_group_metadata_list
             for seq_id in group.seq_data)
    for request_id, seq_id in pairs:
        all_ids.append(seq_id)
        ids_by_request[request_id].add(seq_id)

    return all_ids, ids_by_request
|
||||
|
||||
|
||||
class HiddenStates:
|
||||
"""Hidden states corresponding to in-progress sequences.
|
||||
Used in speculative decoding to pass hidden states from
|
||||
|
||||
Reference in New Issue
Block a user