[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)

This commit is contained in:
SangBin Cho
2024-04-26 22:02:02 +09:00
committed by GitHub
parent a88081bf76
commit 603ad84815
18 changed files with 859 additions and 630 deletions

View File

@@ -83,30 +83,27 @@ def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
logits_processed = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
logits_row_idx += sampling_metadata.prompt_lens[i] - 1
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
token_ids = seq_group.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_row_idx == logits.shape[0]
assert logits_processed == logits.shape[0]
return logits