# Source: smollora/train_lora.py (331 lines, 12 KiB, Python)
#!/usr/bin/env python3
"""
LoRA fine-tuning script for SmolLM3-3B tool-calling.
Uses PEFT + transformers + accelerate. Runs inside the Docker container.
Usage:
python train_lora.py \
--data-dir /data \
--model HuggingFaceTB/SmolLM3-3B \
--output-dir /output \
--epochs 3 \
--batch-size 4 \
--lr 2e-4
"""
import argparse
import json
from pathlib import Path
import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
Trainer,
TrainingArguments,
)
def load_jsonl(path):
    """Read a JSONL file and return its records as a list of dicts.

    Blank lines are skipped; every non-empty line must be valid JSON.
    """
    with open(path) as fh:
        return [json.loads(raw) for raw in fh if raw.strip()]
def tokenize_for_training(sample, tokenizer, max_length=4096):
    """Tokenize one chat-formatted sample and build training labels.

    Only assistant turns are trained on: labels are -100 for every other
    token so the loss ignores system/user/tool content.

    CRITICAL: apply_chat_template() renders markers like <|tool_call_start|>
    as plain text, which the tokenizer would split into a multi-token
    sequence. _tokenize_with_special_tokens() collapses each marker back to
    its single special token ID (e.g. 128015) so the model learns to emit
    the special token itself, not the text spelling.

    Args:
        sample: dict with a "messages" list in chat format.
        tokenizer: model tokenizer; must provide apply_chat_template() and
            encode().
        max_length: hard truncation limit in tokens.

    Returns:
        dict with equal-length "input_ids", "attention_mask", "labels" lists.
    """
    # (text_string, special_token_id) pairs: markers the chat template emits
    # as text but which must appear as single token IDs in training.
    TOKEN_REPLACEMENTS = [
        ("<|tool_call_start|>", 128015),
        ("<|tool_call_end|>", 128016),
        ("<|tool_response_start|>", 128013),
        ("<|tool_response_end|>", 128014),
    ]

    text = tokenizer.apply_chat_template(
        sample["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )

    # Build the token sequence exactly once, with markers collapsed to their
    # single IDs.  (The previous version tokenized the text a second time
    # with return_offsets_mapping=True and built labels from it, then threw
    # all of that away — pure wasted work, now removed.)
    input_ids = _tokenize_with_special_tokens(text, tokenizer, TOKEN_REPLACEMENTS, max_length)
    attention_mask = [1] * len(input_ids)
    labels = [-100] * len(input_ids)

    # ── Label masking: only train on assistant turns ─────────────────────────
    # Scan the token sequence for "<|im_start|>assistant" ... "<|im_end|>"
    # spans and unmask them (markers included, which stabilizes learning of
    # turn boundaries).  Marker token IDs are computed once, outside the loop.
    assistant_marker_ids = tokenizer.encode("<|im_start|>assistant", add_special_tokens=False)
    end_marker_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)
    newline_ids = tokenizer.encode("\n", add_special_tokens=False)
    newline_id = newline_ids[0] if newline_ids else None  # guard: some tokenizers drop lone "\n"

    pos = 0
    while pos < len(input_ids):
        if input_ids[pos:pos + len(assistant_marker_ids)] != assistant_marker_ids:
            pos += 1
            continue
        # Found an assistant turn start; skip the newline right after the marker.
        content_start = pos + len(assistant_marker_ids)
        if newline_id is not None and content_start < len(input_ids) and input_ids[content_start] == newline_id:
            content_start += 1
        # Advance to the closing <|im_end|> (or to the end of the sequence if
        # the turn was cut off by truncation).
        end_pos = content_start
        while end_pos < len(input_ids):
            if input_ids[end_pos:end_pos + len(end_marker_ids)] == end_marker_ids:
                break
            end_pos += 1
        if end_pos < len(input_ids):
            span_end = end_pos + len(end_marker_ids)
        else:
            span_end = len(input_ids)
        # Unmask the assistant span (including the markers for stability).
        for i in range(pos, min(span_end, len(labels))):
            labels[i] = input_ids[i]
        pos = span_end

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
def _tokenize_with_special_tokens(text, tokenizer, replacements, max_length):
    """Tokenize *text*, emitting each marker in *replacements* as one token ID.

    The text is split at every occurrence of a marker string; the plain
    segments between markers are tokenized with add_special_tokens=False and
    the markers themselves are inserted as their single special token IDs.

    BUGFIX: the old no-marker fallback called tokenizer.encode() with its
    default add_special_tokens=True, while the segment path used False — so
    samples without tool calls could pick up an extra BOS/EOS that samples
    with tool calls never had.  Encoding is now uniform for both cases.

    Args:
        text: fully rendered chat string.
        tokenizer: tokenizer providing encode().
        replacements: iterable of (marker_text, special_token_id) pairs.
            Marker strings are assumed not to overlap each other in *text*.
        max_length: maximum number of tokens returned (hard truncation).

    Returns:
        list[int] of token IDs, at most max_length long.
    """
    # Locate every marker occurrence: (char_start, char_end, special_id).
    spans = []
    for marker, special_id in replacements:
        idx = text.find(marker)
        while idx != -1:
            spans.append((idx, idx + len(marker), special_id))
            idx = text.find(marker, idx + len(marker))
    spans.sort(key=lambda s: s[0])

    # Tokenize the plain text between markers; splice in the special IDs.
    result_ids = []
    prev_end = 0
    for char_start, char_end, special_id in spans:
        if char_start > prev_end:
            result_ids.extend(
                tokenizer.encode(text[prev_end:char_start], add_special_tokens=False)
            )
        result_ids.append(special_id)  # single special token ID, never its text spelling
        prev_end = char_end
    # Remaining text after the last marker (or the whole text if no markers).
    if prev_end < len(text):
        result_ids.extend(tokenizer.encode(text[prev_end:], add_special_tokens=False))

    # Hard truncation to the training context window.
    return result_ids[:max_length]
def main():
    """Parse CLI args, load SmolLM3 + tokenizer, attach LoRA, and train.

    Side effects: reads train.jsonl / val.jsonl from --data-dir, writes
    checkpoints to --output-dir and the final adapter + tokenizer to
    {output-dir}/final.
    """
    parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
    parser.add_argument("--data-dir", type=str, default="/data")
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
    parser.add_argument("--output-dir", type=str, default="/output")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--max-length", type=int, default=4096)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fp16", action="store_true", default=False)
    # BooleanOptionalAction keeps "--bf16" working and adds "--no-bf16",
    # so full-fp32 runs are now reachable from the CLI.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--resume-from", type=str, default=None)
    args = parser.parse_args()

    # BUGFIX: bf16 previously defaulted to True with no off switch, so
    # passing --fp16 set BOTH fp16 and bf16 — a combination TrainingArguments
    # rejects.  An explicit --fp16 now overrides the bf16 default, and the
    # same resolved flags drive both the model load dtype and the trainer.
    use_fp16 = args.fp16
    use_bf16 = args.bf16 and not use_fp16
    if use_bf16:
        torch_dtype = torch.bfloat16
    elif use_fp16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    print(f"Loading tokenizer: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Causal-LM tokenizers often ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model: {args.model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
        device_map="auto",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
            "embed_tokens",  # Critical: lets LoRA adjust tool-call token embeddings
        ],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    data_dir = Path(args.data_dir)
    train_data = load_jsonl(data_dir / "train.jsonl")
    val_data = load_jsonl(data_dir / "val.jsonl")
    print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")

    # ── Verify tool-call tokens are in the training data ─────────────────
    # Sanity-check one sample: after the chat template, the special token IDs
    # (128015/128016) must appear, otherwise training cannot teach the model
    # to emit them.
    print("Verifying tool-call token IDs in training data ...")
    verification_sample = train_data[0]
    v_text = tokenizer.apply_chat_template(verification_sample["messages"], tokenize=False)
    v_ids = tokenizer.encode(v_text)
    tc_start_found = 128015 in v_ids
    tc_end_found = 128016 in v_ids
    if tc_start_found and tc_end_found:
        print(f"  ✓ Tool-call tokens verified in sample data (128015={tc_start_found}, 128016={tc_end_found})")
    else:
        print(f"  ✗ WARNING: Tool-call tokens missing! (128015={tc_start_found}, 128016={tc_end_found})")
        print(f"  ✗ Training may NOT teach the model to emit tool-call tokens.")
        print(f"  ✗ Check prepare_data.py and the tokenizer chat template.")
        # Don't abort — let the user decide, but warn loudly

    print("Tokenizing training data ...")
    train_dataset = Dataset.from_list(train_data).map(
        lambda x: tokenize_for_training(x, tokenizer, args.max_length),
        remove_columns=["messages"],
        desc="Tokenizing train",
    )
    val_dataset = Dataset.from_list(val_data).map(
        lambda x: tokenize_for_training(x, tokenizer, args.max_length),
        remove_columns=["messages"],
        desc="Tokenizing val",
    )

    # Seq2Seq collator pads labels with -100 alongside input_ids, which is
    # exactly what the masked-label scheme above needs.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=use_bf16,
        fp16=use_fp16,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_torch_fused",
        seed=args.seed,
        report_to="none",
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )

    print("Starting training ...")
    trainer.train(resume_from_checkpoint=args.resume_from)

    print(f"Saving LoRA adapter to {args.output_dir}/final")
    model.save_pretrained(f"{args.output_dir}/final")
    tokenizer.save_pretrained(f"{args.output_dir}/final")
    print("Done!")
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()