Fix tokenization: replace text token sequences with actual special token IDs
This commit is contained in:
193
train_lora.py
193
train_lora.py
@@ -44,6 +44,11 @@ def tokenize_for_training(sample, tokenizer, max_length=4096):
|
||||
"""Tokenize a chat-formatted sample and build labels.
|
||||
|
||||
Masks everything except assistant responses (labels = -100 for non-assistant tokens).
|
||||
|
||||
CRITICAL: After tokenization, replaces text-token sequences for special tokens
|
||||
(like <|tool_call_start|>) with the actual special token IDs (like 128015).
|
||||
apply_chat_template() renders these as regular text, but the model needs to learn
|
||||
to emit the single special token, not the multi-token text sequence.
|
||||
"""
|
||||
messages = sample["messages"]
|
||||
text = tokenizer.apply_chat_template(
|
||||
@@ -59,52 +64,120 @@ def tokenize_for_training(sample, tokenizer, max_length=4096):
|
||||
return_offsets_mapping=True,
|
||||
)
|
||||
|
||||
input_ids = enc["input_ids"]
|
||||
attention_mask = enc["attention_mask"]
|
||||
input_ids = list(enc["input_ids"])
|
||||
attention_mask = list(enc["attention_mask"])
|
||||
labels = [-100] * len(input_ids)
|
||||
|
||||
# ── Replace text token sequences with actual special token IDs ──────────
|
||||
# The tokenizer renders <|tool_call_start|> as [27, 91, 14506, 13735, 5011, 91, 397]
|
||||
# But we need it as [128015] so the model learns to emit the special token.
|
||||
# We do this by finding the text sequences and replacing them with the single ID.
|
||||
|
||||
TOKEN_REPLACEMENTS = [
|
||||
# (text_string, special_token_id)
|
||||
("<|tool_call_start|>", 128015),
|
||||
("<|tool_call_end|>", 128016),
|
||||
("<|tool_response_start|>", 128013),
|
||||
("<|tool_response_end|>", 128014),
|
||||
]
|
||||
|
||||
for text_str, special_id in TOKEN_REPLACEMENTS:
|
||||
# Find all occurrences in the text
|
||||
search_pos = 0
|
||||
while True:
|
||||
idx = text.find(text_str, search_pos)
|
||||
if idx == -1:
|
||||
break
|
||||
|
||||
# Find the token span that covers this text range
|
||||
char_start = idx
|
||||
char_end = idx + len(text_str)
|
||||
|
||||
tok_start = None
|
||||
tok_end = None
|
||||
for ti, (cs, ce) in enumerate(enc.get("offset_mapping", [])):
|
||||
if cs >= char_start and tok_start is None:
|
||||
tok_start = ti
|
||||
if ce >= char_end:
|
||||
tok_end = ti + 1
|
||||
break
|
||||
|
||||
if tok_start is not None and tok_end is not None:
|
||||
# Replace the multi-token sequence with the single special token
|
||||
input_ids[tok_start] = special_id
|
||||
# Remove the extra tokens
|
||||
remove_count = tok_end - tok_start - 1
|
||||
if remove_count > 0:
|
||||
del input_ids[tok_start + 1:tok_start + 1 + remove_count]
|
||||
del attention_mask[tok_start + 1:tok_start + 1 + remove_count]
|
||||
del labels[tok_start + 1:tok_start + 1 + remove_count]
|
||||
# Adjust labels list length
|
||||
# We also need to re-adjust subsequent replacement offsets,
|
||||
# so we recalculate from the modified lists
|
||||
# Easiest: rebuild from scratch for each replacement
|
||||
# But since we process sequentially, just re-encode
|
||||
# Actually, let's just re-tokenize the modified text instead
|
||||
|
||||
search_pos = idx + len(text_str)
|
||||
|
||||
# The above in-place replacement is tricky with shifting indices.
|
||||
# Simpler approach: replace in the text first, then re-tokenize.
|
||||
# But special tokens can't be in the text for regular tokenization.
|
||||
# So let's use a two-pass approach: tokenize the text without the special
|
||||
# token strings, then insert the special token IDs at the right positions.
|
||||
|
||||
# Actually, the simplest correct approach: rebuild input_ids by encoding
|
||||
# segments between the special token markers.
|
||||
input_ids = _tokenize_with_special_tokens(text, tokenizer, TOKEN_REPLACEMENTS, max_length)
|
||||
attention_mask = [1] * len(input_ids)
|
||||
labels = [-100] * len(input_ids)
|
||||
|
||||
# ── Label masking: only train on assistant turns ─────────────────────────
|
||||
# We need to find assistant turn boundaries in the token sequence.
|
||||
# Re-decode to find positions (this is a bit wasteful but correct).
|
||||
decoded_text = tokenizer.decode(input_ids, skip_special_tokens=False)
|
||||
|
||||
ASSISTANT_MARKER = "<|im_start|>assistant"
|
||||
END_MARKER = "<|im_end|>"
|
||||
|
||||
offsets = enc.get("offset_mapping", [])
|
||||
if not offsets:
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
# Find assistant turn boundaries by searching in the decoded text
|
||||
# and mapping character offsets back to token positions.
|
||||
# Since we can't easily do offset mapping on the rebuilt sequence,
|
||||
# we use a simpler approach: scan for the im_start assistant marker tokens.
|
||||
|
||||
# Encode the markers
|
||||
assistant_marker_ids = tokenizer.encode(ASSISTANT_MARKER, add_special_tokens=False)
|
||||
end_marker_ids = tokenizer.encode(END_MARKER, add_special_tokens=False)
|
||||
|
||||
pos = 0
|
||||
while True:
|
||||
start_idx = text.find(ASSISTANT_MARKER, pos)
|
||||
if start_idx == -1:
|
||||
break
|
||||
while pos < len(input_ids):
|
||||
# Look for assistant marker
|
||||
if input_ids[pos:pos+len(assistant_marker_ids)] == assistant_marker_ids:
|
||||
# Found assistant turn start
|
||||
content_start = pos + len(assistant_marker_ids)
|
||||
# Skip newline after marker if present
|
||||
if content_start < len(input_ids) and input_ids[content_start] == tokenizer.encode("\n", add_special_tokens=False)[0]:
|
||||
content_start += 1
|
||||
|
||||
# Content starts after the marker + newline
|
||||
content_start = start_idx + len(ASSISTANT_MARKER)
|
||||
if content_start < len(text) and text[content_start] == "\n":
|
||||
content_start += 1
|
||||
# Find end marker
|
||||
end_pos = content_start
|
||||
while end_pos < len(input_ids):
|
||||
if input_ids[end_pos:end_pos+len(end_marker_ids)] == end_marker_ids:
|
||||
break
|
||||
end_pos += 1
|
||||
|
||||
end_idx = text.find(END_MARKER, content_start)
|
||||
if end_idx == -1:
|
||||
span_end = len(text)
|
||||
else:
|
||||
span_end = end_idx + len(END_MARKER)
|
||||
if end_pos < len(input_ids):
|
||||
span_end = end_pos + len(end_marker_ids)
|
||||
else:
|
||||
span_end = len(input_ids)
|
||||
|
||||
tok_start = None
|
||||
tok_end = None
|
||||
for ti, (cs, ce) in enumerate(offsets):
|
||||
if cs >= content_start and tok_start is None:
|
||||
tok_start = ti
|
||||
if ce >= span_end:
|
||||
tok_end = ti + 1
|
||||
break
|
||||
|
||||
if tok_start is not None and tok_end is not None:
|
||||
for i in range(tok_start, min(tok_end, len(labels))):
|
||||
# Label the assistant tokens (including markers for stability)
|
||||
for i in range(pos, min(span_end, len(labels))):
|
||||
labels[i] = input_ids[i]
|
||||
|
||||
pos = (end_idx if end_idx != -1 else span_end) + 1
|
||||
pos = span_end
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
@@ -113,6 +186,58 @@ def tokenize_for_training(sample, tokenizer, max_length=4096):
|
||||
}
|
||||
|
||||
|
||||
def _tokenize_with_special_tokens(text, tokenizer, replacements, max_length):
|
||||
"""Tokenize text but preserve special token IDs as single tokens.
|
||||
|
||||
Splits the text at special token markers, tokenizes each segment normally,
|
||||
then reassembles with the special token IDs inserted.
|
||||
"""
|
||||
# Find all special token positions
|
||||
segments = [] # list of (char_start, char_end, is_special, special_id_or_text)
|
||||
search_pos = 0
|
||||
for text_str, special_id in replacements:
|
||||
idx = text.find(text_str, search_pos)
|
||||
while idx != -1:
|
||||
segments.append((idx, idx + len(text_str), True, special_id))
|
||||
search_pos = idx + len(text_str)
|
||||
idx = text.find(text_str, search_pos)
|
||||
search_pos = 0 # reset for next replacement
|
||||
|
||||
if not segments:
|
||||
# No special tokens found, just tokenize normally
|
||||
return tokenizer.encode(text, truncation=True, max_length=max_length)
|
||||
|
||||
# Sort by position
|
||||
segments.sort(key=lambda x: x[0])
|
||||
|
||||
# Build the token sequence
|
||||
result_ids = []
|
||||
prev_end = 0
|
||||
for char_start, char_end, is_special, value in segments:
|
||||
# Tokenize the text before this segment
|
||||
if char_start > prev_end:
|
||||
prefix_text = text[prev_end:char_start]
|
||||
prefix_ids = tokenizer.encode(prefix_text, add_special_tokens=False)
|
||||
result_ids.extend(prefix_ids)
|
||||
|
||||
if is_special:
|
||||
result_ids.append(value) # Single special token ID
|
||||
|
||||
prev_end = char_end
|
||||
|
||||
# Remaining text after last segment
|
||||
if prev_end < len(text):
|
||||
suffix_text = text[prev_end:]
|
||||
suffix_ids = tokenizer.encode(suffix_text, add_special_tokens=False)
|
||||
result_ids.extend(suffix_ids)
|
||||
|
||||
# Truncate to max_length
|
||||
if len(result_ids) > max_length:
|
||||
result_ids = result_ids[:max_length]
|
||||
|
||||
return result_ids
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
|
||||
parser.add_argument("--data-dir", type=str, default="/data")
|
||||
|
||||
Reference in New Issue
Block a user