[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)
This PR combines prepare_prompt and prepare_decode into a single API. It also coalesces the attn metadata for prefill/decode into a single class and allows slicing it when running the attn backend. Finally, it refactors subquery_start_loc, which was not refactored in the previous PR.
This commit is contained in:
@@ -293,21 +293,30 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
|
||||
|
||||
new_seq_data_dict = {
|
||||
target_seq_id:
|
||||
SequenceData(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
output_token_ids=new_output_token_ids,
|
||||
),
|
||||
}
|
||||
# This is a hack. Technically, spec decoding should compute
|
||||
# num_lookahead slots at one shot, but instead, it expands the batch
|
||||
# and evaluates them one by one right now. context_len is seq_len - 1 because
|
||||
# the kv cache is filled by a previous batch in the batch expansion.
|
||||
for data in new_seq_data_dict.values():
|
||||
data.update_num_computed_tokens(data.get_len() - 1)
|
||||
|
||||
return SequenceGroupMetadata(
|
||||
request_id=seq_group_metadata.request_id,
|
||||
is_prompt=seq_group_metadata.is_prompt,
|
||||
seq_data={
|
||||
target_seq_id:
|
||||
SequenceData(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
output_token_ids=new_output_token_ids,
|
||||
),
|
||||
},
|
||||
seq_data=new_seq_data_dict,
|
||||
sampling_params=seq_group_metadata.sampling_params,
|
||||
block_tables={
|
||||
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
||||
},
|
||||
lora_request=None,
|
||||
token_chunk_size=1,
|
||||
)
|
||||
|
||||
def _split_scoring_output(
|
||||
|
||||
@@ -114,6 +114,7 @@ class MultiStepWorker(Worker):
|
||||
token_logprob = seq_output.logprobs[token_id]
|
||||
|
||||
seq.append_token_id(token_id, token_logprob.logprob)
|
||||
seq.update_num_computed_tokens(1)
|
||||
|
||||
def _shallow_copy_inputs(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
|
||||
Reference in New Issue
Block a user