[Bugfix] Fix logits processor when prompt_logprobs is not None (#3899)
This commit is contained in:
@@ -86,8 +86,16 @@ def _apply_logits_processors(
|
||||
) -> torch.Tensor:
|
||||
logits_row_idx = 0
|
||||
found_logits_processors = False
|
||||
for seq_ids, sampling_params in sampling_metadata.seq_groups:
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
logits_processors = sampling_params.logits_processors
|
||||
# handle prompt_logprobs by skipping rows in logits added for
|
||||
# the prompt tokens (prompt logprobs are not processed)
|
||||
if (i < sampling_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
assert len(seq_ids) == 1
|
||||
logits_row_idx += sampling_metadata.prompt_lens[i] - 1
|
||||
|
||||
if logits_processors:
|
||||
found_logits_processors = True
|
||||
for seq_id in seq_ids:
|
||||
@@ -100,5 +108,6 @@ def _apply_logits_processors(
|
||||
else:
|
||||
logits_row_idx += len(seq_ids)
|
||||
if found_logits_processors:
|
||||
# verifies that no rows in logits were missed unexpectedly
|
||||
assert logits_row_idx == logits.shape[0]
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user