Implement presence and frequency penalties (#95)

This commit is contained in:
Woosuk Kwon
2023-05-10 23:39:12 -07:00
committed by GitHub
parent 9f88db35da
commit 55f8b0a5de
9 changed files with 215 additions and 82 deletions

View File

@@ -3,11 +3,12 @@ import time
from typing import Dict, List, Optional, Tuple
from cacheflow.core.block_manager import BlockSpaceManager
from cacheflow.logger import init_logger
from cacheflow.core.policy import PolicyFactory
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceOutputs, SequenceStatus)
from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus)
logger = init_logger(__name__)
@@ -246,27 +247,17 @@ class Scheduler:
group_id = seq_group.group_id
is_prompt = group_id in prompt_group_ids
input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
input_tokens[seq_id] = seq.get_token_ids()
else:
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
# NOTE(woosuk): Sequences in the same group have the same
# sequence length
seq_len = seq.get_len()
seq_group_metadata = SequenceGroupMetadata(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
seq_data=seq_data,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)