Fix cu_num_generated_tokens slicing logic in LogprobsLists.slice() method (#28214)
Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
This commit is contained in:
@@ -30,16 +30,23 @@ class LogprobsLists(NamedTuple):
|
||||
if self.cu_num_generated_tokens:
|
||||
start = self.cu_num_generated_tokens[start_req_idx]
|
||||
end = self.cu_num_generated_tokens[end_req_idx]
|
||||
# Recompute cumulative array starting from 0
|
||||
cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
|
||||
sliced_cu_num_generated_tokens = [
|
||||
cu_num - cu_num_offset
|
||||
for cu_num in self.cu_num_generated_tokens[
|
||||
start_req_idx : end_req_idx + 1
|
||||
]
|
||||
]
|
||||
else:
|
||||
start = start_req_idx
|
||||
end = end_req_idx
|
||||
sliced_cu_num_generated_tokens = None
|
||||
return LogprobsLists(
|
||||
self.logprob_token_ids[start:end],
|
||||
self.logprobs[start:end],
|
||||
self.sampled_token_ranks[start:end],
|
||||
self.cu_num_generated_tokens[start_req_idx:end_req_idx]
|
||||
if self.cu_num_generated_tokens
|
||||
else None,
|
||||
sliced_cu_num_generated_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user