2026-04-10 05:11:05 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
"""
|
|
|
|
|
LoRA fine-tuning script for SmolLM3-3B tool-calling.

Uses PEFT + transformers + accelerate. Runs inside the Docker container.

Usage:
    python train_lora.py \
        --data-dir /data \
        --model HuggingFaceTB/SmolLM3-3B \
        --output-dir /output \
        --epochs 3 \
        --batch-size 4 \
        --lr 2e-4
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import json
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from datasets import Dataset
|
|
|
|
|
from peft import LoraConfig, TaskType, get_peft_model
|
|
|
|
|
from transformers import (
|
|
|
|
|
AutoModelForCausalLM,
|
|
|
|
|
AutoTokenizer,
|
|
|
|
|
DataCollatorForSeq2Seq,
|
|
|
|
|
Trainer,
|
|
|
|
|
TrainingArguments,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
2026-04-10 06:24:05 +00:00
|
|
|
def load_jsonl(path):
    """Load a JSON Lines file into a list of decoded objects.

    Each non-blank line is parsed with ``json.loads``; blank and
    whitespace-only lines are skipped.

    Args:
        path: Path (``str`` or ``os.PathLike``) of the ``.jsonl`` file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with ``path:line_number`` so the bad record can be found.
    """
    samples = []
    # JSON text is UTF-8 by specification; do not depend on the locale's
    # default encoding, which breaks non-ASCII content on some systems.
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Same exception type as before, with file/line context added.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                ) from exc
    return samples
|
|
|
|
|
|
|
|
|
|
|
2026-04-10 06:24:05 +00:00
|
|
|
def _char_span_to_token_span(offsets, char_start, char_end):
    """Map the character range [char_start, char_end) to a half-open token range.

    ``offsets`` is a sequence of ``(start_char, end_char)`` pairs, one per token.
    The range starts at the first token that begins at or after ``char_start``
    and ends just after the first token whose end reaches ``char_end``.

    If no token reaches ``char_end`` (the text was truncated inside the span),
    the range runs to the end of the sequence, excluding trailing tokens whose
    offsets end at or before ``char_start`` (e.g. zero-width special tokens).

    Returns ``(0, 0)`` when no token starts inside the span.
    """
    tok_start = None
    for idx, (tok_char_start, tok_char_end) in enumerate(offsets):
        if tok_start is None and tok_char_start >= char_start:
            tok_start = idx
        if tok_char_end >= char_end:
            if tok_start is None:
                return 0, 0
            return tok_start, idx + 1

    if tok_start is None:
        return 0, 0

    # Span end lies beyond the (truncated) token window: keep what is visible.
    tok_end = len(offsets)
    while tok_end > tok_start and offsets[tok_end - 1][1] <= char_start:
        tok_end -= 1
    return tok_start, tok_end


def tokenize_for_training(sample, tokenizer, max_length=4096):
    """Tokenize a chat-formatted sample and build labels.

    Masks everything except assistant responses (labels = -100 for
    non-assistant tokens). An assistant response runs from just after
    ``<|im_start|>assistant`` (and one following newline, if present) through
    the closing ``<|im_end|>`` marker, inclusive.

    If truncation to ``max_length`` cuts an assistant response short, the part
    that is still inside the window is labeled; previously such a response was
    left fully masked.

    Args:
        sample: Mapping with a ``"messages"`` entry accepted by
            ``tokenizer.apply_chat_template``.
        tokenizer: Object providing ``apply_chat_template`` and a ``__call__``
            that returns ``input_ids``, ``attention_mask`` and (optionally)
            ``offset_mapping``.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` (same length
        as ``input_ids``).
    """
    ASSISTANT_MARKER = "<|im_start|>assistant"
    END_MARKER = "<|im_end|>"

    text = tokenizer.apply_chat_template(
        sample["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )

    # NOTE(review): the tokenizer's default special-token handling is used on
    # already-templated text; if this tokenizer prepends BOS/appends EOS by
    # default, consider add_special_tokens=False — confirm for the model used.
    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True,
    )

    input_ids = enc["input_ids"]
    labels = [-100] * len(input_ids)
    result = {
        "input_ids": input_ids,
        "attention_mask": enc["attention_mask"],
        "labels": labels,
    }

    offsets = enc.get("offset_mapping", [])
    if not offsets:
        # Without character offsets the spans cannot be located; everything
        # stays masked.
        return result

    search_from = 0
    while True:
        marker_at = text.find(ASSISTANT_MARKER, search_from)
        if marker_at == -1:
            break

        # Content starts after the marker + newline
        content_start = marker_at + len(ASSISTANT_MARKER)
        if text.startswith("\n", content_start):
            content_start += 1

        end_at = text.find(END_MARKER, content_start)
        span_end = len(text) if end_at == -1 else end_at + len(END_MARKER)

        tok_start, tok_end = _char_span_to_token_span(offsets, content_start, span_end)
        for i in range(tok_start, min(tok_end, len(labels))):
            labels[i] = input_ids[i]

        if end_at == -1:
            # Unterminated final turn: nothing left to scan.
            break
        search_from = end_at + 1

    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: LoRA fine-tune a causal LM on chat data.

    Steps, in order:
      1. Parse CLI arguments and resolve the numeric precision.
      2. Load the tokenizer and base model, then wrap the model with LoRA.
      3. Load ``train.jsonl`` / ``val.jsonl`` from ``--data-dir``.
      4. Check the first training sample for the expected tool-call token IDs
         (warning only).
      5. Tokenize both splits with ``tokenize_for_training``.
      6. Train with ``transformers.Trainer`` and save the adapter and tokenizer
         to ``<output-dir>/final``.

    Precision flags: bf16 is the default. ``--fp16`` alone selects fp16,
    ``--no-bf16`` alone selects fp32, and ``--fp16`` together with ``--bf16``
    is rejected. (Previously ``--bf16`` was always on, so ``--fp16`` set both
    flags, which ``TrainingArguments`` rejects.)

    Raises:
        SystemExit: On invalid arguments or when a data split is empty.
    """
    parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
    parser.add_argument("--data-dir", type=str, default="/data")
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
    parser.add_argument("--output-dir", type=str, default="/output")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--max-length", type=int, default=4096)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fp16", action="store_true", default=False)
    # default=None means "not specified", so an explicit --bf16 can be told
    # apart from the implicit default when resolving precision below.
    parser.add_argument("--bf16", dest="bf16", action="store_true", default=None)
    parser.add_argument("--no-bf16", dest="bf16", action="store_false", default=None)
    parser.add_argument("--resume-from", type=str, default=None)
    args = parser.parse_args()

    # Resolve precision so that at most one of fp16/bf16 is enabled.
    if args.bf16 is None:
        args.bf16 = not args.fp16
    elif args.bf16 and args.fp16:
        parser.error("--fp16 and --bf16 are mutually exclusive")

    if args.bf16:
        model_dtype = torch.bfloat16
    elif args.fp16:
        model_dtype = torch.float16
    else:
        model_dtype = torch.float32

    # NOTE(review): trust_remote_code=True allows Python code shipped in the
    # model repository to run; review before pointing --model at an untrusted repo.
    print(f"Loading tokenizer: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # The collator needs a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model: {args.model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=model_dtype,
        device_map="auto",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
            "embed_tokens",  # Critical: lets LoRA adjust tool-call token embeddings
        ],
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    data_dir = Path(args.data_dir)
    train_path = data_dir / "train.jsonl"
    val_path = data_dir / "val.jsonl"
    train_data = load_jsonl(train_path)
    val_data = load_jsonl(val_path)
    print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")

    # Fail with a clear message instead of an IndexError / column error later.
    for split_path, split in ((train_path, train_data), (val_path, val_data)):
        if not split:
            raise SystemExit(f"No samples found in {split_path}")

    # ── Verify tool-call tokens are in the training data ─────────────────
    # NOTE(review): 128015 / 128016 are hard-coded vocabulary IDs, presumably
    # the tool-call open/close tokens of the default model's tokenizer —
    # confirm, and expect this check to be meaningless for another --model.
    # Only the first training sample is inspected.
    print("Verifying tool-call token IDs in training data ...")
    verification_sample = train_data[0]
    v_text = tokenizer.apply_chat_template(verification_sample["messages"], tokenize=False)
    v_ids = tokenizer.encode(v_text)
    tc_start_found = 128015 in v_ids
    tc_end_found = 128016 in v_ids
    if tc_start_found and tc_end_found:
        print(f" ✓ Tool-call tokens verified in sample data (128015={tc_start_found}, 128016={tc_end_found})")
    else:
        print(f" ✗ WARNING: Tool-call tokens missing! (128015={tc_start_found}, 128016={tc_end_found})")
        print(" ✗ Training may NOT teach the model to emit tool-call tokens.")
        print(" ✗ Check prepare_data.py and the tokenizer chat template.")
        # Don't abort — let the user decide, but warn loudly

    def tokenize_split(samples, desc):
        """Tokenize one split and drop the raw ``messages`` column."""
        return Dataset.from_list(samples).map(
            lambda x: tokenize_for_training(x, tokenizer, args.max_length),
            remove_columns=["messages"],
            desc=desc,
        )

    print("Tokenizing training data ...")
    train_dataset = tokenize_split(train_data, "Tokenizing train")
    val_dataset = tokenize_split(val_data, "Tokenizing val")

    # Pads input_ids/attention_mask and labels per batch.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=args.bf16,
        fp16=args.fp16,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_torch_fused",
        seed=args.seed,
        report_to="none",
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )

    print("Starting training ...")
    trainer.train(resume_from_checkpoint=args.resume_from)

    final_dir = f"{args.output_dir}/final"
    print(f"Saving LoRA adapter to {final_dir}")
    model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)

    print("Done!")
|
2026-04-10 05:11:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the training pipeline only when this file is executed as a script, so
# that importing the module does not start training.
if __name__ == "__main__":
    main()
|