Implement presence and frequency penalties (#95)
This commit is contained in:
@@ -3,11 +3,12 @@ import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from cacheflow.core.block_manager import BlockSpaceManager
|
||||
from cacheflow.logger import init_logger
|
||||
from cacheflow.core.policy import PolicyFactory
|
||||
from cacheflow.logger import init_logger
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
|
||||
SequenceOutputs, SequenceStatus)
|
||||
from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -246,27 +247,17 @@ class Scheduler:
|
||||
group_id = seq_group.group_id
|
||||
is_prompt = group_id in prompt_group_ids
|
||||
|
||||
input_tokens: Dict[int, List[int]] = {}
|
||||
seq_logprobs: Dict[int, float] = {}
|
||||
seq_data: Dict[int, List[SequenceData]] = {}
|
||||
block_tables: Dict[int, List[int]] = {}
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
seq_id = seq.seq_id
|
||||
seq_data[seq_id] = seq.data
|
||||
block_tables[seq_id] = self.block_manager.get_block_table(seq)
|
||||
if is_prompt:
|
||||
input_tokens[seq_id] = seq.get_token_ids()
|
||||
else:
|
||||
input_tokens[seq_id] = [seq.get_last_token_id()]
|
||||
seq_logprobs[seq_id] = seq.cumulative_logprobs
|
||||
# NOTE(woosuk): Sequences in the same group have the same
|
||||
# sequence length
|
||||
seq_len = seq.get_len()
|
||||
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
group_id=group_id,
|
||||
is_prompt=is_prompt,
|
||||
input_tokens=input_tokens,
|
||||
context_len=seq_len,
|
||||
seq_logprobs=seq_logprobs,
|
||||
seq_data=seq_data,
|
||||
sampling_params=self.sampling_params[group_id],
|
||||
block_tables=block_tables,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user