diff --git a/train_lora.py b/train_lora.py index 2ffce98..00d989d 100644 --- a/train_lora.py +++ b/train_lora.py @@ -81,53 +81,7 @@ def tokenize_for_training(sample, tokenizer, max_length=4096): ("<|tool_response_end|>", 128014), ] - for text_str, special_id in TOKEN_REPLACEMENTS: - # Find all occurrences in the text - search_pos = 0 - while True: - idx = text.find(text_str, search_pos) - if idx == -1: - break - - # Find the token span that covers this text range - char_start = idx - char_end = idx + len(text_str) - - tok_start = None - tok_end = None - for ti, (cs, ce) in enumerate(enc.get("offset_mapping", [])): - if cs >= char_start and tok_start is None: - tok_start = ti - if ce >= char_end: - tok_end = ti + 1 - break - - if tok_start is not None and tok_end is not None: - # Replace the multi-token sequence with the single special token - input_ids[tok_start] = special_id - # Remove the extra tokens - remove_count = tok_end - tok_start - 1 - if remove_count > 0: - del input_ids[tok_start + 1:tok_start + 1 + remove_count] - del attention_mask[tok_start + 1:tok_start + 1 + remove_count] - del labels[tok_start + 1:tok_start + 1 + remove_count] - # Adjust labels list length - # We also need to re-adjust subsequent replacement offsets, - # so we recalculate from the modified lists - # Easiest: rebuild from scratch for each replacement - # But since we process sequentially, just re-encode - # Actually, let's just re-tokenize the modified text instead - - search_pos = idx + len(text_str) - - # The above in-place replacement is tricky with shifting indices. - # Simpler approach: replace in the text first, then re-tokenize. - # But special tokens can't be in the text for regular tokenization. - # So let's use a two-pass approach: tokenize the text without the special - # token strings, then insert the special token IDs at the right positions. - - # Actually, the simplest correct approach: rebuild input_ids by encoding - # segments between the special token markers. + # Rebuild input_ids with special tokens as single token IDs input_ids = _tokenize_with_special_tokens(text, tokenizer, TOKEN_REPLACEMENTS, max_length) attention_mask = [1] * len(input_ids) labels = [-100] * len(input_ids)