[Misc] Misc code simplifications (#26450)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -147,22 +147,20 @@ class RejectionSampler(nn.Module):
|
||||
sampling_metadata: SamplingMetadata,
|
||||
metadata: SpecDecodeMetadata,
|
||||
) -> torch.Tensor:
|
||||
has_penalties = not sampling_metadata.no_penalties
|
||||
any_penalties_or_bad_words = (
|
||||
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
|
||||
sampling_metadata.bad_words_token_ids or has_penalties
|
||||
)
|
||||
|
||||
output_token_ids = sampling_metadata.output_token_ids
|
||||
if any_penalties_or_bad_words:
|
||||
output_token_ids = self._combine_outputs_with_spec_tokens(
|
||||
sampling_metadata.output_token_ids,
|
||||
output_token_ids,
|
||||
sampling_metadata.spec_token_ids,
|
||||
)
|
||||
|
||||
# Calculate indices of target logits.
|
||||
if (
|
||||
sampling_metadata.allowed_token_ids_mask is not None
|
||||
or not sampling_metadata.no_penalties
|
||||
):
|
||||
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
|
||||
num_requests = len(sampling_metadata.output_token_ids)
|
||||
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
|
||||
original_indices = torch.arange(num_requests, device="cpu")
|
||||
@@ -180,18 +178,15 @@ class RejectionSampler(nn.Module):
|
||||
logits.masked_fill_(token_mask, float("-inf"))
|
||||
|
||||
# Apply bad words exclusion.
|
||||
if sampling_metadata.bad_words_token_ids:
|
||||
if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
|
||||
apply_bad_words_with_drafts(
|
||||
logits,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
output_token_ids,
|
||||
metadata.num_draft_tokens,
|
||||
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
metadata: SpecDecodeMetadata,
|
||||
@@ -218,8 +213,8 @@ class RejectionSampler(nn.Module):
|
||||
)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _combine_outputs_with_spec_tokens(
|
||||
self,
|
||||
output_token_ids: list[list[int]],
|
||||
spec_token_ids: Optional[list[list[int]]] = None,
|
||||
) -> list[list[int]]:
|
||||
|
||||
Reference in New Issue
Block a user