[Speculative decoding] Support target-model logprobs (#4378)
This commit is contained in:
@@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose(
|
||||
def create_sampler_output_list(
|
||||
token_ids: torch.Tensor,
|
||||
probs: Iterable[Optional[torch.Tensor]],
|
||||
logprobs: Iterable[Optional[torch.Tensor]],
|
||||
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
|
||||
num_steps, batch_size = token_ids.shape
|
||||
token_ids_by_step = token_ids.tolist()
|
||||
@@ -222,6 +223,7 @@ def create_sampler_output_list(
|
||||
) for seq_index, token_id in enumerate(token_ids_by_step[step])
|
||||
],
|
||||
sampled_token_probs=probs[step],
|
||||
logprobs=logprobs[step],
|
||||
sampled_token_ids=token_ids[step])
|
||||
for step in range(num_steps)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user