Fix tokenization: replace text token sequences with actual special token IDs

This commit is contained in:
Jinx
2026-04-10 17:28:51 +00:00
parent d3b5f04f88
commit ca50973065

View File

@@ -44,6 +44,11 @@ def tokenize_for_training(sample, tokenizer, max_length=4096):
"""Tokenize a chat-formatted sample and build labels.
Masks everything except assistant responses (labels = -100 for non-assistant tokens).
CRITICAL: After tokenization, replaces text-token sequences for special tokens
(like <|tool_call_start|>) with the actual special token IDs (like 128015).
apply_chat_template() renders these as regular text, but the model needs to learn
to emit the single special token, not the multi-token text sequence.
"""
messages = sample["messages"]
text = tokenizer.apply_chat_template(
@@ -59,52 +64,120 @@ def tokenize_for_training(sample, tokenizer, max_length=4096):
return_offsets_mapping=True,
)
input_ids = enc["input_ids"]
attention_mask = enc["attention_mask"]
input_ids = list(enc["input_ids"])
attention_mask = list(enc["attention_mask"])
labels = [-100] * len(input_ids)
# ── Replace text token sequences with actual special token IDs ──────────
# The tokenizer renders <|tool_call_start|> as [27, 91, 14506, 13735, 5011, 91, 397]
# But we need it as [128015] so the model learns to emit the special token.
# We do this by finding the text sequences and replacing them with the single ID.
TOKEN_REPLACEMENTS = [
# (text_string, special_token_id)
("<|tool_call_start|>", 128015),
("<|tool_call_end|>", 128016),
("<|tool_response_start|>", 128013),
("<|tool_response_end|>", 128014),
]
for text_str, special_id in TOKEN_REPLACEMENTS:
# Find all occurrences in the text
search_pos = 0
while True:
idx = text.find(text_str, search_pos)
if idx == -1:
break
# Find the token span that covers this text range
char_start = idx
char_end = idx + len(text_str)
tok_start = None
tok_end = None
for ti, (cs, ce) in enumerate(enc.get("offset_mapping", [])):
if cs >= char_start and tok_start is None:
tok_start = ti
if ce >= char_end:
tok_end = ti + 1
break
if tok_start is not None and tok_end is not None:
# Replace the multi-token sequence with the single special token
input_ids[tok_start] = special_id
# Remove the extra tokens
remove_count = tok_end - tok_start - 1
if remove_count > 0:
del input_ids[tok_start + 1:tok_start + 1 + remove_count]
del attention_mask[tok_start + 1:tok_start + 1 + remove_count]
del labels[tok_start + 1:tok_start + 1 + remove_count]
# Adjust labels list length
# We also need to re-adjust subsequent replacement offsets,
# so we recalculate from the modified lists
# Easiest: rebuild from scratch for each replacement
# But since we process sequentially, just re-encode
# Actually, let's just re-tokenize the modified text instead
search_pos = idx + len(text_str)
# The above in-place replacement is tricky with shifting indices.
# Simpler approach: replace in the text first, then re-tokenize.
# But special tokens can't be in the text for regular tokenization.
# So let's use a two-pass approach: tokenize the text without the special
# token strings, then insert the special token IDs at the right positions.
# Actually, the simplest correct approach: rebuild input_ids by encoding
# segments between the special token markers.
input_ids = _tokenize_with_special_tokens(text, tokenizer, TOKEN_REPLACEMENTS, max_length)
attention_mask = [1] * len(input_ids)
labels = [-100] * len(input_ids)
# ── Label masking: only train on assistant turns ─────────────────────────
# We need to find assistant turn boundaries in the token sequence.
# Re-decode to find positions (this is a bit wasteful but correct).
decoded_text = tokenizer.decode(input_ids, skip_special_tokens=False)
ASSISTANT_MARKER = "<|im_start|>assistant"
END_MARKER = "<|im_end|>"
offsets = enc.get("offset_mapping", [])
if not offsets:
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
# Find assistant turn boundaries by searching in the decoded text
# and mapping character offsets back to token positions.
# Since we can't easily do offset mapping on the rebuilt sequence,
# we use a simpler approach: scan for the im_start assistant marker tokens.
# Encode the markers
assistant_marker_ids = tokenizer.encode(ASSISTANT_MARKER, add_special_tokens=False)
end_marker_ids = tokenizer.encode(END_MARKER, add_special_tokens=False)
pos = 0
while True:
start_idx = text.find(ASSISTANT_MARKER, pos)
if start_idx == -1:
break
while pos < len(input_ids):
# Look for assistant marker
if input_ids[pos:pos+len(assistant_marker_ids)] == assistant_marker_ids:
# Found assistant turn start
content_start = pos + len(assistant_marker_ids)
# Skip newline after marker if present
if content_start < len(input_ids) and input_ids[content_start] == tokenizer.encode("\n", add_special_tokens=False)[0]:
content_start += 1
# Content starts after the marker + newline
content_start = start_idx + len(ASSISTANT_MARKER)
if content_start < len(text) and text[content_start] == "\n":
content_start += 1
# Find end marker
end_pos = content_start
while end_pos < len(input_ids):
if input_ids[end_pos:end_pos+len(end_marker_ids)] == end_marker_ids:
break
end_pos += 1
end_idx = text.find(END_MARKER, content_start)
if end_idx == -1:
span_end = len(text)
else:
span_end = end_idx + len(END_MARKER)
if end_pos < len(input_ids):
span_end = end_pos + len(end_marker_ids)
else:
span_end = len(input_ids)
tok_start = None
tok_end = None
for ti, (cs, ce) in enumerate(offsets):
if cs >= content_start and tok_start is None:
tok_start = ti
if ce >= span_end:
tok_end = ti + 1
break
if tok_start is not None and tok_end is not None:
for i in range(tok_start, min(tok_end, len(labels))):
# Label the assistant tokens (including markers for stability)
for i in range(pos, min(span_end, len(labels))):
labels[i] = input_ids[i]
pos = (end_idx if end_idx != -1 else span_end) + 1
pos = span_end
else:
pos += 1
return {
"input_ids": input_ids,
@@ -113,6 +186,58 @@ def tokenize_for_training(sample, tokenizer, max_length=4096):
}
def _tokenize_with_special_tokens(text, tokenizer, replacements, max_length):
"""Tokenize text but preserve special token IDs as single tokens.
Splits the text at special token markers, tokenizes each segment normally,
then reassembles with the special token IDs inserted.
"""
# Find all special token positions
segments = [] # list of (char_start, char_end, is_special, special_id_or_text)
search_pos = 0
for text_str, special_id in replacements:
idx = text.find(text_str, search_pos)
while idx != -1:
segments.append((idx, idx + len(text_str), True, special_id))
search_pos = idx + len(text_str)
idx = text.find(text_str, search_pos)
search_pos = 0 # reset for next replacement
if not segments:
# No special tokens found, just tokenize normally
return tokenizer.encode(text, truncation=True, max_length=max_length)
# Sort by position
segments.sort(key=lambda x: x[0])
# Build the token sequence
result_ids = []
prev_end = 0
for char_start, char_end, is_special, value in segments:
# Tokenize the text before this segment
if char_start > prev_end:
prefix_text = text[prev_end:char_start]
prefix_ids = tokenizer.encode(prefix_text, add_special_tokens=False)
result_ids.extend(prefix_ids)
if is_special:
result_ids.append(value) # Single special token ID
prev_end = char_end
# Remaining text after last segment
if prev_end < len(text):
suffix_text = text[prev_end:]
suffix_ids = tokenizer.encode(suffix_text, add_special_tokens=False)
result_ids.extend(suffix_ids)
# Truncate to max_length
if len(result_ids) > max_length:
result_ids = result_ids[:max_length]
return result_ids
def main():
parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
parser.add_argument("--data-dir", type=str, default="/data")