[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -68,6 +68,7 @@ class RejectionSampler(nn.Module):
|
||||
# NOTE: The following input preparation can be moved
|
||||
# to the model runner in a persistent manner for better
|
||||
# performance.
|
||||
assert sampling_metadata.spec_token_ids is not None
|
||||
spec_token_ids = sampling_metadata.spec_token_ids
|
||||
max_spec_len = max(len(s) for s in spec_token_ids)
|
||||
batch_size = len(spec_token_ids)
|
||||
@@ -119,6 +120,7 @@ class RejectionSampler(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
assert sampling_metadata.spec_token_ids is not None
|
||||
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
|
||||
# Add 1 to include the 'bonus' token.
|
||||
sample_lens = [x + 1 for x in spec_lens]
|
||||
|
||||
Reference in New Issue
Block a user