[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)

This PR combines prepare_prompt and prepare_decode into a single API. This PR also coalesces the attn metadata for prefill/decode into a single class and allows slicing it when running the attn backend.

It also refactors subquery_start_loc, which was not refactored in the previous PR.
This commit is contained in:
SangBin Cho
2024-05-15 14:00:10 +09:00
committed by GitHub
parent 8a7cc254a0
commit 65bf2ac165
18 changed files with 781 additions and 734 deletions

View File

@@ -293,21 +293,30 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
prompt_token_ids = seq_data.get_prompt_token_ids()
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
new_seq_data_dict = {
target_seq_id:
SequenceData(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
),
}
# This is a hack. Technically, spec decoding should compute
# num_lookahead slots at one shot, but instead, it expands the batch
# and evaluate one by one right now. context_len is seq_len - 1 because
# the kv cache is filled by a previous batch in the batch expansion.
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data={
target_seq_id:
SequenceData(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
),
},
seq_data=new_seq_data_dict,
sampling_params=seq_group_metadata.sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
token_chunk_size=1,
)
def _split_scoring_output(

View File

@@ -114,6 +114,7 @@ class MultiStepWorker(Worker):
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)
def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]