[V1][Minor] Simplify rejection sampler's parse_output (#15741)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-03-29 09:25:17 -07:00
committed by GitHub
parent c67abd614f
commit 2bc4be4e32
2 changed files with 3 additions and 11 deletions

View File

@@ -1121,16 +1121,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
discard_sampled_tokens_req_indices,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if not self.use_spec_decode:
spec_token_ids = None