[Bugfix] Fix Random Dataset Prefix Length Inaccuracy (#33907)
Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -380,7 +380,7 @@ def gen_prompt_decode_to_target_len(
|
||||
max_retry: int = 10,
|
||||
add_special_tokens: bool = False,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> tuple[str, list[int]]:
|
||||
) -> tuple[str, list[int], int]:
|
||||
"""
|
||||
Ensure decoded-then-encoded prompt length matches the target token length.
|
||||
|
||||
@@ -392,7 +392,9 @@ def gen_prompt_decode_to_target_len(
|
||||
[6880, 6881] -> ['Ġcalls', 'here'] ->
|
||||
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
||||
|
||||
Returns a tuple of the final prompt string and the adjusted token sequence.
|
||||
Returns a tuple of the final prompt string, the adjusted token sequence,
|
||||
and the token mismatch (final_len - target_token_len) if the retry budget
|
||||
is exhausted.
|
||||
"""
|
||||
remain_num_try = max_retry
|
||||
token_mismatch = 0
|
||||
@@ -499,7 +501,7 @@ class RandomDataset(BenchmarkDataset):
|
||||
allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
|
||||
|
||||
# Generate prefix once
|
||||
prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)
|
||||
prefix_token_ids = self.get_prefix(tokenizer, allowed_tokens, prefix_len)
|
||||
|
||||
requests = []
|
||||
token_mismatch_total = 0
|
||||
@@ -554,19 +556,36 @@ class RandomDataset(BenchmarkDataset):
|
||||
|
||||
def get_prefix(
|
||||
self,
|
||||
tokenizer: TokenizerLike,
|
||||
allowed_tokens: np.ndarray,
|
||||
prefix_len: int,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get the prefix for the dataset.
|
||||
"""
|
||||
return (
|
||||
allowed_tokens[
|
||||
self._rng.integers(0, len(allowed_tokens), size=prefix_len)
|
||||
].tolist()
|
||||
if prefix_len > 0
|
||||
else []
|
||||
if prefix_len <= 0:
|
||||
return []
|
||||
|
||||
prefix_tokens = allowed_tokens[
|
||||
self._rng.integers(0, len(allowed_tokens), size=prefix_len)
|
||||
].tolist()
|
||||
_, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(
|
||||
tokenizer=tokenizer,
|
||||
token_sequence=prefix_tokens,
|
||||
target_token_len=prefix_len,
|
||||
add_special_tokens=False,
|
||||
rng=self._rng,
|
||||
)
|
||||
if token_mismatch != 0:
|
||||
sign = "more" if token_mismatch > 0 else "fewer"
|
||||
logger.warning(
|
||||
"Prefix tokenization produced %d %s tokens than expected "
|
||||
"after decoding and re-encoding. This is expected due to "
|
||||
"the imperfect nature of the sampling procedure",
|
||||
abs(token_mismatch),
|
||||
sign,
|
||||
)
|
||||
return adjusted_tokens
|
||||
|
||||
def get_sampling_params(
|
||||
self,
|
||||
@@ -1128,7 +1147,7 @@ class RandomMultiModalDataset(RandomDataset):
|
||||
"Sampling from %d out of %d (vocab size)", len(allowed_tokens), vocab_size
|
||||
)
|
||||
# Generate prefix once
|
||||
prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)
|
||||
prefix_token_ids = self.get_prefix(tokenizer, allowed_tokens, prefix_len)
|
||||
# Add synthetic multimodal items to each request
|
||||
mm_requests = []
|
||||
token_mismatch_total = 0
|
||||
|
||||
Reference in New Issue
Block a user