[Speculative decoding] Support target-model logprobs (#4378)

This commit is contained in:
Cade Daniel
2024-05-03 15:52:01 -07:00
committed by GitHub
parent 43c413ec57
commit ab50275111
15 changed files with 727 additions and 86 deletions

View File

@@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]],
logprobs: Iterable[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
@@ -222,6 +223,7 @@ def create_sampler_output_list(
) for seq_index, token_id in enumerate(token_ids_by_step[step])
],
sampled_token_probs=probs[step],
logprobs=logprobs[step],
sampled_token_ids=token_ids[step])
for step in range(num_steps)
]