Remove broken in-place replacement code, use _tokenize_with_special_tokens only

This commit is contained in:
Jinx
2026-04-10 17:35:09 +00:00
parent ca50973065
commit 278d87286a

View File

@@ -81,53 +81,7 @@ def tokenize_for_training(sample, tokenizer, max_length=4096):
("<|tool_response_end|>", 128014),
]
for text_str, special_id in TOKEN_REPLACEMENTS:
# Find all occurrences in the text
search_pos = 0
while True:
idx = text.find(text_str, search_pos)
if idx == -1:
break
# Find the token span that covers this text range
char_start = idx
char_end = idx + len(text_str)
tok_start = None
tok_end = None
for ti, (cs, ce) in enumerate(enc.get("offset_mapping", [])):
if cs >= char_start and tok_start is None:
tok_start = ti
if ce >= char_end:
tok_end = ti + 1
break
if tok_start is not None and tok_end is not None:
# Replace the multi-token sequence with the single special token
input_ids[tok_start] = special_id
# Remove the extra tokens
remove_count = tok_end - tok_start - 1
if remove_count > 0:
del input_ids[tok_start + 1:tok_start + 1 + remove_count]
del attention_mask[tok_start + 1:tok_start + 1 + remove_count]
del labels[tok_start + 1:tok_start + 1 + remove_count]
# Adjust labels list length
# We also need to re-adjust subsequent replacement offsets,
# so we recalculate from the modified lists
# Easiest: rebuild from scratch for each replacement
# But since we process sequentially, just re-encode
# Actually, let's just re-tokenize the modified text instead
search_pos = idx + len(text_str)
# In-place replacement is tricky because deleting tokens shifts the indices
# of everything that follows.
# Replacing the marker strings in the text and then re-tokenizing does not
# work either: special tokens can't be in the text for regular tokenization.
# A two-pass alternative would tokenize the text without the special token
# strings and then insert the special token IDs at the right positions.
# The simplest correct approach is to encode the segments between the special
# token markers and rebuild input_ids with each marker as a single token ID.
input_ids = _tokenize_with_special_tokens(text, tokenizer, TOKEN_REPLACEMENTS, max_length)
attention_mask = [1] * len(input_ids)
labels = [-100] * len(input_ids)