[V1][Minor] Simplify rejection sampler's parse_output (#15741)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -1121,16 +1121,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if max_gen_len == 1:
|
||||
# No spec decode tokens.
|
||||
valid_sampled_token_ids = sampled_token_ids.tolist()
|
||||
# Mask out the sampled tokens that should not be sampled.
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
else:
|
||||
# Includes spec decode tokens.
|
||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
||||
sampled_token_ids,
|
||||
discard_sampled_tokens_req_indices,
|
||||
self.input_batch.vocab_size,
|
||||
)
|
||||
# Mask out the sampled tokens that should not be sampled.
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
|
||||
if not self.use_spec_decode:
|
||||
spec_token_ids = None
|
||||
|
||||
Reference in New Issue
Block a user