[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)

This commit is contained in:
sroy745
2024-07-10 16:02:47 -07:00
committed by GitHub
parent 44cc76610d
commit ae151d73be
14 changed files with 645 additions and 80 deletions

View File

@@ -3,8 +3,9 @@ import copy
import enum
import math
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
import torch
@@ -916,6 +917,21 @@ def get_all_seq_ids(
return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect every sequence id and group the ids by request id.

    Walks the sequence groups in order and returns a tuple of:

    * a flat list of the keys of each group's ``seq_data``, in iteration
      order, and
    * a ``defaultdict(set)`` mapping each group's ``request_id`` to the
      set of sequence ids found in that group.

    A group whose ``seq_data`` is empty contributes nothing: no ids are
    appended and no entry is created for its request id.
    """
    all_seq_ids: List[int] = []
    ids_by_request: Dict[str, Set[int]] = defaultdict(set)
    for group in seq_group_metadata_list:
        group_seq_ids = list(group.seq_data)
        if not group_seq_ids:
            # Skip before indexing the defaultdict so that no empty set is
            # registered for a request that has no sequences.
            continue
        all_seq_ids.extend(group_seq_ids)
        ids_by_request[group.request_id].update(group_seq_ids)
    return all_seq_ids, ids_by_request
class HiddenStates:
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from