[V1] Refactor num_computed_tokens logic (#15307)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Cody Yu
Date:   2025-03-26 21:54:36 -07:00 (committed by GitHub)
Parent: fb22be5817
Commit: 54aa619459

5 changed files with 106 additions and 57 deletions

@@ -1085,8 +1085,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
-        for i, generator in self.input_batch.generators.items():
-            req_id = self.input_batch.req_ids[i]
+        discard_sampled_tokens_req_indices = []
+        for i, req_id in enumerate(self.input_batch.req_ids):
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
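
The first hunk changes which requests the loop visits: instead of only requests that own a seeded generator, it now enumerates every request in the batch and collects the indices of partial prefills whose sampled token must be discarded. A minimal standalone sketch of that control flow, with flat dicts standing in for self.input_batch and self.requests, and with the partial-prefill test assumed to be seq_len < the request's total token count (the condition itself sits between these hunks):

    # Sketch only; all names below are illustrative stand-ins.
    req_ids = ["req-0", "req-1"]
    num_computed_tokens = {"req-0": 16, "req-1": 4}    # already processed
    num_scheduled_tokens = {"req-0": 1, "req-1": 4}    # scheduled this step
    num_tokens = {"req-0": 17, "req-1": 32}            # total tokens per request

    discard_sampled_tokens_req_indices = []
    for i, req_id in enumerate(req_ids):
        seq_len = num_computed_tokens[req_id] + num_scheduled_tokens[req_id]
        if seq_len < num_tokens[req_id]:
            # Partial prefill: the prompt is not fully processed yet, so the
            # token sampled this step must be ignored.
            discard_sampled_tokens_req_indices.append(i)

    print(discard_sampled_tokens_req_indices)  # -> [1]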
@@ -1094,7 +1094,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 # Ignore the sampled token for partial prefills.
                 # Rewind the generator state as if the token was not sampled.
                 # This relies on cuda-specific torch-internal impl details
-                generator.set_offset(generator.get_offset() - 4)
+                generator = self.input_batch.generators.get(i)
+                if generator is not None:
+                    generator.set_offset(generator.get_offset() - 4)
+                # Record the index of the request that should not be sampled,
+                # so that we could clear the sampled tokens before returning.
+                discard_sampled_tokens_req_indices.append(i)
 
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
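
The generator is now fetched with .get(i) since the loop no longer iterates the generators dict, and the rewind only runs for requests that actually have one. The rewind itself deserves a note: resetting the Philox offset makes the CUDA RNG replay, so the discarded sample leaves no trace on the random stream. A hedged illustration, runnable only on a CUDA device; how far one sampling kernel advances the offset is a torch-internal detail, which is what the hard-coded "- 4" above relies on:

    import torch

    # Sketch: save/restore the Philox offset around a draw we want to undo.
    gen = torch.Generator(device="cuda").manual_seed(42)
    probs = torch.full((8,), 0.125, device="cuda")

    offset_before = gen.get_offset()
    first = torch.multinomial(probs, 1, generator=gen)   # the draw to discard

    gen.set_offset(offset_before)                        # rewind the state
    again = torch.multinomial(probs, 1, generator=gen)
    assert torch.equal(first, again)  # same offset, same input -> same sample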
@@ -1114,10 +1119,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if max_gen_len == 1:
             # No spec decode tokens.
             valid_sampled_token_ids = sampled_token_ids.tolist()
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
         else:
             # Includes spec decode tokens.
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids, self.input_batch.vocab_size)
+                sampled_token_ids,
+                discard_sampled_tokens_req_indices,
+                self.input_batch.vocab_size,
+            )
 
         if not self.use_spec_decode:
             spec_token_ids = None
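
In the no-spec-decode branch the masking happens on the Python list after the GPU -> CPU sync: clearing an inner list (rather than deleting it) keeps indices aligned with the batch while signalling that no valid token exists for that request. The spec-decode branch instead hands the same indices to rejection_sampler.parse_output, presumably so it can apply the equivalent masking internally. A small self-contained illustration of the list-side masking:

    # One list of sampled token ids per request in the batch.
    valid_sampled_token_ids = [[123], [456], [789]]
    discard_sampled_tokens_req_indices = [1]   # request 1 is a partial prefill

    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[i].clear()     # empty, but index 1 still exists

    print(valid_sampled_token_ids)  # -> [[123], [], [789]]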