Fix cu_num_generated_tokens slicing logic in LogprobsLists.slice() method (#28214)

Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
This commit is contained in:
usberkeley
2025-11-10 03:11:46 +08:00
committed by GitHub
parent 636efd10a5
commit 4a8d6bd168
2 changed files with 111 additions and 3 deletions

View File

@@ -30,16 +30,23 @@ class LogprobsLists(NamedTuple):
if self.cu_num_generated_tokens:
start = self.cu_num_generated_tokens[start_req_idx]
end = self.cu_num_generated_tokens[end_req_idx]
# Recompute cumulative array starting from 0
cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
sliced_cu_num_generated_tokens = [
cu_num - cu_num_offset
for cu_num in self.cu_num_generated_tokens[
start_req_idx : end_req_idx + 1
]
]
else:
start = start_req_idx
end = end_req_idx
sliced_cu_num_generated_tokens = None
return LogprobsLists(
self.logprob_token_ids[start:end],
self.logprobs[start:end],
self.sampled_token_ranks[start:end],
self.cu_num_generated_tokens[start_req_idx:end_req_idx]
if self.cu_num_generated_tokens
else None,
sliced_cu_num_generated_tokens,
)