#!/usr/bin/env python3
"""LoRA fine-tuning script for SmolLM3-3B tool-calling.

Uses PEFT + transformers + accelerate. Runs inside the Docker container.

Usage:
    python train_lora.py \
        --data-dir /data \
        --model HuggingFaceTB/SmolLM3-3B \
        --output-dir /output \
        --epochs 3 \
        --batch-size 4 \
        --lr 2e-4
"""

import argparse
import json
from pathlib import Path

import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)

# Text renderings of the tool-call markers and the single special-token IDs
# the model must learn to emit in their place. apply_chat_template() renders
# these as regular text (e.g. [27, 91, 14506, 13735, 5011, 91, 397]); training
# needs the single IDs (e.g. [128015]) instead.
TOKEN_REPLACEMENTS = [
    # (text_string, special_token_id)
    ("<|tool_call_start|>", 128015),
    ("<|tool_call_end|>", 128016),
    ("<|tool_response_start|>", 128013),
    ("<|tool_response_end|>", 128014),
]


def load_jsonl(path):
    """Read a JSONL file and return a list of parsed objects (blank lines skipped)."""
    samples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                samples.append(json.loads(line))
    return samples


def tokenize_for_training(sample, tokenizer, max_length=4096):
    """Tokenize a chat-formatted sample and build labels.

    Masks everything except assistant responses (labels = -100 for
    non-assistant tokens).

    CRITICAL: after rendering with apply_chat_template(), the text-token
    sequences for the special markers (like "<|tool_call_start|>") are
    replaced with the actual special token IDs (like 128015), so the model
    learns to emit the single special token, not the multi-token text
    sequence.

    Args:
        sample: dict with a "messages" list in chat format.
        tokenizer: HF tokenizer with a chat template.
        max_length: truncation length in tokens.

    Returns:
        dict with "input_ids", "attention_mask", "labels" (parallel lists).
    """
    messages = sample["messages"]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Build input_ids with special tokens as single token IDs.
    # (The previous version first ran a full throwaway tokenization with
    # offset mapping and then discarded it — that dead work is removed.)
    input_ids = _tokenize_with_special_tokens(text, tokenizer, TOKEN_REPLACEMENTS, max_length)
    attention_mask = [1] * len(input_ids)
    labels = [-100] * len(input_ids)

    # ── Label masking: only train on assistant turns ─────────────────────
    # Scan the token sequence for the assistant-turn start/end marker tokens.
    ASSISTANT_MARKER = "<|im_start|>assistant"
    END_MARKER = "<|im_end|>"
    assistant_marker_ids = tokenizer.encode(ASSISTANT_MARKER, add_special_tokens=False)
    end_marker_ids = tokenizer.encode(END_MARKER, add_special_tokens=False)

    # Hoisted out of the scan loop; some tokenizers encode "\n" to nothing,
    # so guard against an empty result instead of indexing [0] blindly.
    newline_ids = tokenizer.encode("\n", add_special_tokens=False)
    newline_id = newline_ids[0] if newline_ids else None

    n = len(input_ids)
    pos = 0
    while pos < n:
        if input_ids[pos:pos + len(assistant_marker_ids)] == assistant_marker_ids:
            # Found an assistant turn start.
            content_start = pos + len(assistant_marker_ids)
            # Skip the newline after the marker if present.
            if newline_id is not None and content_start < n and input_ids[content_start] == newline_id:
                content_start += 1
            # Find the end marker (slice compare is safe past the end:
            # a short slice never equals end_marker_ids).
            end_pos = content_start
            while end_pos < n:
                if input_ids[end_pos:end_pos + len(end_marker_ids)] == end_marker_ids:
                    break
                end_pos += 1
            span_end = end_pos + len(end_marker_ids) if end_pos < n else n
            # Label the assistant tokens (including markers for stability).
            for i in range(pos, min(span_end, len(labels))):
                labels[i] = input_ids[i]
            pos = span_end
        else:
            pos += 1

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


def _tokenize_with_special_tokens(text, tokenizer, replacements, max_length):
    """Tokenize text but preserve special token IDs as single tokens.

    Splits the text at special token markers, tokenizes each segment
    normally (add_special_tokens=False), then reassembles with the special
    token IDs inserted. The result is truncated to max_length.
    """
    # Locate every occurrence of every marker: (char_start, char_end, special_id).
    segments = []
    for text_str, special_id in replacements:
        idx = text.find(text_str)
        while idx != -1:
            segments.append((idx, idx + len(text_str), special_id))
            idx = text.find(text_str, idx + len(text_str))

    if not segments:
        # No special tokens found — tokenize plainly. add_special_tokens=False
        # keeps this path consistent with the segment path below: the chat
        # template already supplies all structural tokens, so the tokenizer
        # must not prepend a BOS here (the old code let it).
        ids = tokenizer.encode(text, add_special_tokens=False)
        return ids[:max_length]

    segments.sort(key=lambda s: s[0])

    result_ids = []
    prev_end = 0
    for char_start, char_end, special_id in segments:
        # Tokenize the plain text before this marker.
        if char_start > prev_end:
            result_ids.extend(
                tokenizer.encode(text[prev_end:char_start], add_special_tokens=False)
            )
        result_ids.append(special_id)  # single special token ID
        prev_end = char_end

    # Remaining text after the last marker.
    if prev_end < len(text):
        result_ids.extend(tokenizer.encode(text[prev_end:], add_special_tokens=False))

    return result_ids[:max_length]


def main():
    parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
    parser.add_argument("--data-dir", type=str, default="/data")
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
    parser.add_argument("--output-dir", type=str, default="/output")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--max-length", type=int, default=4096)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fp16", action="store_true", default=False)
    # BooleanOptionalAction adds --no-bf16; the old store_true/default=True
    # combination made bf16 impossible to disable.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--resume-from", type=str, default=None)
    args = parser.parse_args()

    # TrainingArguments raises if fp16 and bf16 are both true; an explicit
    # --fp16 request wins over the bf16 default.
    if args.fp16:
        args.bf16 = False

    print(f"Loading tokenizer: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model: {args.model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32),
        device_map="auto",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
            "embed_tokens",  # Critical: lets LoRA adjust tool-call token embeddings
        ],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    data_dir = Path(args.data_dir)
    train_data = load_jsonl(data_dir / "train.jsonl")
    val_data = load_jsonl(data_dir / "val.jsonl")
    print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")
    if not train_data:
        # Fail fast with a clear message instead of an IndexError below.
        raise SystemExit(f"No training samples found in {data_dir / 'train.jsonl'}")

    # ── Verify tool-call tokens are in the training data ─────────────────
    # IMPORTANT: check the REBUILT sequence. A plain tokenizer.encode() of the
    # templated text sees the markers as multi-token text (see module docs),
    # so the old check inspected the wrong sequence and could never find the
    # special IDs unless the tokenizer itself performed the replacement.
    print("Verifying tool-call token IDs in training data ...")
    v_text = tokenizer.apply_chat_template(train_data[0]["messages"], tokenize=False)
    v_ids = _tokenize_with_special_tokens(v_text, tokenizer, TOKEN_REPLACEMENTS, args.max_length)
    tc_start_found = 128015 in v_ids
    tc_end_found = 128016 in v_ids
    if tc_start_found and tc_end_found:
        print(f"  ✓ Tool-call tokens verified in sample data (128015={tc_start_found}, 128016={tc_end_found})")
    else:
        print(f"  ✗ WARNING: Tool-call tokens missing! (128015={tc_start_found}, 128016={tc_end_found})")
        print(f"  ✗ Training may NOT teach the model to emit tool-call tokens.")
        print(f"  ✗ Check prepare_data.py and the tokenizer chat template.")
        # Don't abort — let the user decide, but warn loudly

    print("Tokenizing training data ...")
    train_dataset = Dataset.from_list(train_data).map(
        lambda x: tokenize_for_training(x, tokenizer, args.max_length),
        remove_columns=["messages"],
        desc="Tokenizing train",
    )
    val_dataset = Dataset.from_list(val_data).map(
        lambda x: tokenize_for_training(x, tokenizer, args.max_length),
        remove_columns=["messages"],
        desc="Tokenizing val",
    )

    # Seq2Seq collator pads "labels" with -100 alongside input_ids, which is
    # exactly what the masked-label causal-LM setup needs.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=args.bf16,
        fp16=args.fp16,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_torch_fused",
        seed=args.seed,
        report_to="none",
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )

    print("Starting training ...")
    trainer.train(resume_from_checkpoint=args.resume_from)

    print(f"Saving LoRA adapter to {args.output_dir}/final")
    model.save_pretrained(f"{args.output_dir}/final")
    tokenizer.save_pretrained(f"{args.output_dir}/final")
    print("Done!")


if __name__ == "__main__":
    main()