[V1] Optimize handling of sampling metadata and req_ids list (#13244)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-02-18 12:15:33 -08:00
committed by GitHub
parent a4d577b379
commit 30172b4947
15 changed files with 254 additions and 297 deletions

View File

@@ -68,6 +68,7 @@ class RejectionSampler(nn.Module):
# NOTE: The following input preparationg can be moved
# to the model runner with a persistent manner for better
# performance.
assert sampling_metadata.spec_token_ids is not None
spec_token_ids = sampling_metadata.spec_token_ids
max_spec_len = max(len(s) for s in spec_token_ids)
batch_size = len(spec_token_ids)
@@ -119,6 +120,7 @@ class RejectionSampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
assert sampling_metadata.spec_token_ids is not None
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
# Add 1 to include the 'bonus' token.
sample_lens = [x + 1 for x in spec_lens]