Revert "[Core] Performance: Use list[np.ndarray] instead of list[list… (#28773)
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
|
||||
from dataclasses import replace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -205,7 +204,7 @@ class RejectionSampler(nn.Module):
|
||||
def parse_output(
|
||||
output_token_ids: torch.Tensor,
|
||||
vocab_size: int,
|
||||
) -> list[np.ndarray]:
|
||||
) -> list[list[int]]:
|
||||
"""Parse the output of the rejection sampler.
|
||||
Args:
|
||||
output_token_ids: The sampled token IDs in shape
|
||||
@@ -221,7 +220,10 @@ class RejectionSampler(nn.Module):
|
||||
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
||||
output_token_ids_np < vocab_size
|
||||
)
|
||||
return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)]
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs
|
||||
|
||||
def apply_logits_processors(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user