[V1] Refactor num_computed_tokens logic (#15307)
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -1085,8 +1085,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
-        for i, generator in self.input_batch.generators.items():
-            req_id = self.input_batch.req_ids[i]
+        discard_sampled_tokens_req_indices = []
+        for i, req_id in enumerate(self.input_batch.req_ids):
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
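For context, this hunk replaces the loop keyed on requests that have a generator with a pass over every request in the batch, collecting the indices whose sampled token must be discarded. A minimal standalone sketch of the new loop shape, using made-up plain dicts in place of vLLM's actual request/batch structures:

# Sketch of the restructured loop. req_ids, num_computed_tokens,
# num_scheduled_tokens, and num_tokens are made-up stand-ins for the
# fields read in the diff; the values are illustrative only.
req_ids = ["a", "b"]
num_computed_tokens = {"a": 10, "b": 3}
num_scheduled_tokens = {"a": 2, "b": 1}
num_tokens = {"a": 20, "b": 4}  # total prompt + output tokens per request

discard_sampled_tokens_req_indices = []
for i, req_id in enumerate(req_ids):
    seq_len = num_computed_tokens[req_id] + num_scheduled_tokens[req_id]
    if seq_len < num_tokens[req_id]:
        # Partial prefill: the token sampled this step must be discarded.
        discard_sampled_tokens_req_indices.append(i)

print(discard_sampled_tokens_req_indices)  # [0]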
@@ -1094,7 +1094,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 # Ignore the sampled token for partial prefills.
                 # Rewind the generator state as if the token was not sampled.
                 # This relies on cuda-specific torch-internal impl details
-                generator.set_offset(generator.get_offset() - 4)
+                generator = self.input_batch.generators.get(i)
+                if generator is not None:
+                    generator.set_offset(generator.get_offset() - 4)
+                # Record the index of the request that should not be sampled,
+                # so that we could clear the sampled tokens before returning.
+                discard_sampled_tokens_req_indices.append(i)
 
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
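The rewind trick here works because, as the diff's own comment notes, it leans on CUDA-specific torch internals: CUDA generators expose their Philox counter via get_offset/set_offset (available in recent PyTorch), and the "- 4" assumes one sampling call advances the counter by four slots. A hedged sketch of the same effect that saves and restores the offset instead of hard-coding the advance (requires a CUDA device):

import torch

# Sketch of the generator rewind, assuming a CUDA device and a PyTorch
# version whose CUDA generators expose get_offset/set_offset. Restoring a
# saved offset demonstrates the same "as if the token was not sampled"
# behavior without hard-coding the per-call counter advance.
gen = torch.Generator(device="cuda").manual_seed(42)
probs = torch.full((16,), 1 / 16, device="cuda")

offset_before = gen.get_offset()
first = torch.multinomial(probs, num_samples=1, generator=gen)
gen.set_offset(offset_before)  # rewind: forget the draw above
second = torch.multinomial(probs, num_samples=1, generator=gen)
assert torch.equal(first, second)  # identical sample after the rewind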
@@ -1114,10 +1119,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if max_gen_len == 1:
             # No spec decode tokens.
             valid_sampled_token_ids = sampled_token_ids.tolist()
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
         else:
             # Includes spec decode tokens.
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids, self.input_batch.vocab_size)
+                sampled_token_ids,
+                discard_sampled_tokens_req_indices,
+                self.input_batch.vocab_size,
+            )
 
         if not self.use_spec_decode:
             spec_token_ids = None
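In the no-spec-decode branch, masking is done by clearing each discarded request's token list in place, so downstream code sees an empty list and knows no token was sampled for that request. A toy illustration with made-up values:

# Toy illustration of the masking step: clearing the slot in place leaves
# an empty list, which downstream code reads as "no token sampled".
valid_sampled_token_ids = [[11], [22], [33]]
discard_sampled_tokens_req_indices = [0]
for i in discard_sampled_tokens_req_indices:
    valid_sampled_token_ids[i].clear()
print(valid_sampled_token_ids)  # [[], [22], [33]]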