commit 82348341b0f3f22d0272cb348eb63ddc9b59e9e0 Author: Jinx Date: Fri Apr 10 05:11:05 2026 +0000 Initial LoRA training setup for SmolLM3-3B tool calling diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..7a8c891 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,22 @@ +FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel + +# System deps +RUN apt-get update && apt-get install -y --no-install-recommends \ + git ninja-build packaging wget curl \ + && rm -rf /var/lib/apt/lists/* + +# Python deps +COPY requirements.txt /tmp/requirements.txt +RUN pip install --no-cache-dir -r /tmp/requirements.txt + +# Copy scripts +WORKDIR /app +COPY prepare_data.py /app/ +COPY train_lora.py /app/ +COPY run.sh /app/ + +# Data and output dirs +RUN mkdir -p /data/processed /data/lora-output /data/models + +# Default: run the full pipeline +ENTRYPOINT ["/app/run.sh"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..9017079 --- /dev/null +++ b/README.md @@ -0,0 +1,81 @@ +# SmolLM3-3B LoRA — Tool Calling Fine-Tune + +LoRA adapter training to make SmolLM3-3B a tool-calling savant. + +## Quick Start + +```bash +# Build +docker build -t smollora . + +# Run full pipeline (prepare data + train) +docker run --gpus all \ + -v /path/on/host/output:/data/lora-output \ + smollora + +# Skip data prep if you already have processed data +docker run --gpus all \ + -e SKIP_PREP=1 \ + -v /path/on/host/processed:/data/processed \ + -v /path/on/host/output:/data/lora-output \ + smollora +``` + +## Environment Variables + +| Var | Default | Description | +|-----|---------|-------------| +| `MODEL` | `HuggingFaceTB/SmolLM3-3B` | Base model (HF repo or local path) | +| `DATA_DIR` | `/data/processed` | Processed data directory | +| `OUTPUT_DIR` | `/data/lora-output` | Training output directory | +| `EPOCHS` | `3` | Training epochs | +| `BATCH_SIZE` | `4` | Per-device batch size | +| `LR` | `2e-4` | Learning rate | +| `LORA_R` | `16` | LoRA rank | +| `MAX_LENGTH` | `4096` | Max sequence length | +| `SKIP_PREP` | `0` | Set to `1` to skip data preparation | + +## Datasets + +Three datasets combined and converted to SmolLM3's native token format: + +1. **interstellarninja/tool-calls-multiturn** — Multi-turn tool calling conversations +2. **NousResearch/Hermes-Function-Calling-V1** — Hermes-format function calling +3. **Salesforce/xLAM-function-calling-60k** — Large-scale function calling (60k samples) + +Only conversations containing tool calls are kept. All are normalized to SmolLM3's special tokens: +- Tool calls → `startPos`/`endPos` (token IDs 128002/128016) +- Tool responses → `eni`/`eni_result` (token IDs 128013/128014) + +## LoRA Configuration + +- **Rank:** 16 +- **Alpha:** 32 +- **Target modules:** q/k/v/o projections + gate/up/down MLP +- **Dropout:** 0.05 +- **Scheduler:** Cosine with 3% warmup +- **Optimizer:** AdamW (fused) +- **Gradient checkpointing:** Enabled + +## Output + +The trained adapter is saved to `$OUTPUT_DIR/final/`. To use with vLLM: + +```bash +# Merge adapter into base model (recommended for vLLM) +python -m peft import PeftModel +# Or pass the adapter path directly with --enable-lora +``` + +## SSH Deployment + +```bash +# On GPU box, after SSH-ing in: +docker run --gpus all -v ~/smol-data:/data smollora + +# Or with local model cache: +docker run --gpus all \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + -v ~/smol-data:/data \ + smollora +``` diff --git a/model-files/chat_template.jinja b/model-files/chat_template.jinja new file mode 100644 index 0000000..2ac027e --- /dev/null +++ b/model-files/chat_template.jinja @@ -0,0 +1,102 @@ +{# ───── defaults ───── #} +{%- if enable_thinking is not defined -%} +{%- set enable_thinking = true -%} +{%- endif -%} + +{# ───── reasoning mode ───── #} +{%- if enable_thinking -%} + {%- set reasoning_mode = "/think" -%} +{%- else -%} + {%- set reasoning_mode = "/no_think" -%} +{%- endif -%} + +{# ───── header (system message) ───── #} +{{- "<|im_start|>system\n" -}} + +{%- if messages[0].role == "system" -%} + {%- set system_message = messages[0].content -%} + {%- if "/no_think" in system_message -%} + {%- set reasoning_mode = "/no_think" -%} + {%- elif "/think" in system_message -%} + {%- set reasoning_mode = "/think" -%} + {%- endif -%} + {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%} +{%- endif -%} + +{%- if "/system_override" in system_message -%} + {{- custom_instructions.replace("/system_override", "").rstrip() -}} +{%- else -%} + {{- "## Metadata\n\n" -}} + {{- "Knowledge Cutoff Date: June 2025\n" -}} + {%- set today = strftime_now("%d %B %Y") -%} + {{- "Today Date: " ~ today ~ "\n" -}} + {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}} + + {{- "## Custom Instructions\n\n" -}} + {%- if custom_instructions -%} + {{- custom_instructions + "\n\n" -}} + {%- elif reasoning_mode == "/think" -%} + {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}} + {%- else -%} + {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}} + {%- endif -%} + + {%- if xml_tools or python_tools or tools -%} + {{- "### Tools\n\n" -}} + {%- if xml_tools or tools -%} + {%- if tools -%} + {%- set xml_tools = tools -%} + {%- endif -%} + {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within XML tags:\n\n\n") -%} + {%- for tool in xml_tools[:] -%} + {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | tojson) ~ "\n" -%} + {%- endfor -%} + {%- set xml_tool_string = ns.xml_tool_string + "\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n" -%} + {{- xml_tool_string -}} + {%- endif -%} + {%- if python_tools -%} + {%- set ns = namespace(python_tool_string="You may call one or more functions as python tools.\n\n") -%} + {%- for tool in python_tools[:] -%} + {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%} + {%- endfor -%} + {%- set python_tool_string = ns.python_tool_string + "\n\nThe state persists between code executions." -%} + {{- python_tool_string -}} + {%- endif -%} + {{- "\n\n" -}} + {%- endif -%} +{%- endif -%} +{{- "<|im_end|>\n" -}} + +{# ───── main loop ───── #} +{%- for message in messages -%} + {%- if message.role == "user" -%} + {{ "<|im_start|>user\n" + message.content + "<|im_end|>\n" }} + {%- elif message.role == "assistant" -%} + {% generation %} + {%- if message.tool_calls -%} + {%- set ns = namespace(tc_text="") -%} + {%- for tc in message.tool_calls -%} + {%- set ns.tc_text = ns.tc_text ~ "\n{\"name\": \"" ~ tc.function.name ~ "\", \"arguments\": " ~ tc.function.arguments ~ "}\n" -%} + {%- endfor -%} + {{ "<|im_start|>assistant\n" ~ (message.content if message.content is string else "") ~ ns.tc_text ~ "<|im_end|>\n" }} + {%- else -%} + {%- if reasoning_mode == "/think" -%} + {{ "<|im_start|>assistant\n\n" ~ (message.content if message.content is string else "") ~ "\n<|im_end|>\n" }} + {%- else -%} + {{ "<|im_start|>assistant\n" ~ (message.content if message.content is string else "") ~ "<|im_end|>\n" }} + {%- endif -%} + {%- endif -%} + {% endgeneration %} + {%- elif message.role == "tool" -%} + {{ "<|im_start|>user\n\n" ~ (message.content if message.content is string else "") ~ "\n<|im_end|>\n" }} + {%- endif -%} +{%- endfor -%} + +{# ───── generation prompt ───── #} +{%- if add_generation_prompt -%} + {%- if reasoning_mode == "/think" -%} + {{ "<|im_start|>assistant\n\n" }} + {%- else -%} + {{ "<|im_start|>assistant\n" }} + {%- endif -%} +{%- endif -%} diff --git a/model-files/gen_template.py b/model-files/gen_template.py new file mode 100644 index 0000000..81d299b --- /dev/null +++ b/model-files/gen_template.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +"""Generate the PRODUCTION fixed chat_template.jinja for SmolLM3-3B. + +v2: Fixed thinking mode direction - /think now opens unga... tags + in the generation prompt so the model actually generates reasoning. +""" +from transformers import AutoTokenizer +tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B") + +THINK_S = tok.decode([128002]) +THINK_E = tok.decode([128003]) +RESP_S = tok.decode([128013]) +RESP_E = tok.decode([128014]) +TC_S = tok.decode([128015]) +TC_E = tok.decode([128016]) + +T = [] + +# ─── defaults & system header ─── +T.append(r"""{# ───── defaults ───── #} +{%- if enable_thinking is not defined -%} +{%- set enable_thinking = true -%} +{%- endif -%} + +{# ───── reasoning mode ───── #} +{%- if enable_thinking -%} + {%- set reasoning_mode = "/think" -%} +{%- else -%} + {%- set reasoning_mode = "/no_think" -%} +{%- endif -%} + +{# ───── header (system message) ───── #} +{{- "<|im_start|>system\n" -}} + +{%- if messages[0].role == "system" -%} + {%- set system_message = messages[0].content -%} + {%- if "/no_think" in system_message -%} + {%- set reasoning_mode = "/no_think" -%} + {%- elif "/think" in system_message -%} + {%- set reasoning_mode = "/think" -%} + {%- endif -%} + {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%} +{%- endif -%} + +{%- if "/system_override" in system_message -%} + {{- custom_instructions.replace("/system_override", "").rstrip() -}} +{%- else -%} + {{- "## Metadata\n\n" -}} + {{- "Knowledge Cutoff Date: June 2025\n" -}} + {%- set today = strftime_now("%d %B %Y") -%} + {{- "Today Date: " ~ today ~ "\n" -}} + {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}} + + {{- "## Custom Instructions\n\n" -}} + {%- if custom_instructions -%} + {{- custom_instructions + "\n\n" -}} + {%- elif reasoning_mode == "/think" -%} + {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}} + {%- else -%} + {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}} + {%- endif -%} + + {%- if xml_tools or python_tools or tools -%} + {{- "### Tools\n\n" -}} + {%- if xml_tools or tools -%} + {%- if tools -%} + {%- set xml_tools = tools -%} + {%- endif -%} + {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within XML tags:\n\n\n") -%} + {%- for tool in xml_tools[:] -%} + {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | tojson) ~ "\n" -%} + {%- endfor -%}""") + +# Tool calling format with special tokens +T.append('\n {%- set xml_tool_string = ns.xml_tool_string + "\\n\\nFor each function call, return a json object with function name and arguments within ' + TC_S + ' XML tags:\\n' + TC_S + '\\n{\\"name\\": , \\"arguments\\": }\\n' + TC_E + '" -%}\n') + +T.append(r""" {{- xml_tool_string -}} + {%- endif -%} + {%- if python_tools -%} + {%- set ns = namespace(python_tool_string="You may call one or more functions as python tools.\n\n") -%} + {%- for tool in python_tools[:] -%} + {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%} + {%- endfor -%} + {%- set python_tool_string = ns.python_tool_string + "\n\nThe state persists between code executions." -%} + {{- python_tool_string -}} + {%- endif -%} + {{- "\n\n" -}} + {%- endif -%} +{%- endif -%} +{{- "<|im_end|>\n" -}}""") + +# ─── Main loop ─── +T.append(r""" + +{# ───── main loop ───── #} +{%- for message in messages -%} + {%- if message.role == "user" -%} + {{ "<|im_start|>user\n" + message.content + "<|im_end|>\n" }} + {%- elif message.role == "assistant" -%} + {% generation %} + {%- if message.tool_calls -%}""") + +# FIX: Render tool calls with TC_S/TC_E tokens +T.append('\n {%- set ns = namespace(tc_text="") -%}\n {%- for tc in message.tool_calls -%}\n {%- set ns.tc_text = ns.tc_text ~ "' + TC_S + '\\n{\\"name\\": \\"" ~ tc.function.name ~ "\\", \\"arguments\\": " ~ tc.function.arguments ~ "}\\n' + TC_E + '" -%}\n {%- endfor -%}\n {{ "<|im_start|>assistant\\n" ~ (message.content if message.content is string else "") ~ ns.tc_text ~ "<|im_end|>\\n" }}\n') + +T.append(r""" {%- else -%}""") + +# FIX v2: /think = think tags, /no_think = plain text (CORRECT direction now) +T.append('\n {%- if reasoning_mode == "/think" -%}\n {{ "<|im_start|>assistant\\n' + THINK_S + '\\n" ~ (message.content if message.content is string else "") ~ "\\n' + THINK_E + '<|im_end|>\\n" }}\n {%- else -%}\n {{ "<|im_start|>assistant\\n" ~ (message.content if message.content is string else "") ~ "<|im_end|>\\n" }}\n {%- endif -%}\n') + +T.append(r""" {%- endif -%} + {% endgeneration %}""") + +# FIX: Tool role with RESP_S/RESP_E tokens +T.append('\n {%- elif message.role == "tool" -%}\n {{ "<|im_start|>user\\n' + RESP_S + '\\n" ~ (message.content if message.content is string else "") ~ "\\n' + RESP_E + '<|im_end|>\\n" }}\n') + +T.append(r""" {%- endif -%} +{%- endfor -%}""") + +# ─── Generation prompt ─── +# FIX v2: /think opens unga... so model generates reasoning, /no_think is bare +T.append('\n\n{# ───── generation prompt ───── #}\n{%- if add_generation_prompt -%}\n {%- if reasoning_mode == "/think" -%}\n {{ "<|im_start|>assistant\\n' + THINK_S + '\\n" }}\n {%- else -%}\n {{ "<|im_start|>assistant\\n" }}\n {%- endif -%}\n{%- endif -%}\n') + +template = ''.join(T) + +with open('/root/chat_template.jinja', 'w', encoding='utf-8') as f: + f.write(template) + +print("Production template v2 written to /root/chat_template.jinja") +print(f"Length: {len(template)} bytes") diff --git a/prepare_data.py b/prepare_data.py new file mode 100644 index 0000000..fa4419a --- /dev/null +++ b/prepare_data.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 +""" +Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning. + +Combines three datasets: + 1. interstellarninja/tool-calls-multiturn + 2. NousResearch/Hermes-Function-Calling-V1 + 3. Salesforce/xLAM-function-calling-60k + +Converts all to SmolLM3's native chat format with proper special tokens: + - Tool calls wrapped in startPos/endPos tokens (IDs 128002/128016) + - Tool responses wrapped in eni/eni_result tokens (IDs 128013/128014) + - Thinking wrapped in think_start/think_end tags + +Output: train.jsonl, val.jsonl (tokenized & raw) +""" + +import json +import random +import re +from pathlib import Path + +from datasets import load_dataset + +# SmolLM3 special tokens (match the fixed chat_template.jinja) +TOOL_CALL_START = "<|tool_call_start|>" # token 128002 +TOOL_CALL_END = "<|tool_call_end|>" # token 128016 +TOOL_RESP_START = "<|tool_response_start|>" # token 128013 +TOOL_RESP_END = "<|tool_response_end|>" # token 128014 +THINK_START = "" +THINK_END = "" + +VAL_FRACTION = 0.05 +SEED = 42 + + +def render_tool_calls(tool_calls: list[dict]) -> str: + """Render tool_calls list into SmolLM3's native format.""" + parts = [] + for tc in tool_calls: + name = tc["function"]["name"] + args = tc["function"]["arguments"] + if isinstance(args, str): + args_str = args + else: + args_str = json.dumps(args, ensure_ascii=False) + parts.append(f'{{"name": "{name}", "arguments": {args_str}}}') + body = "\n".join(parts) + return f"{TOOL_CALL_START}\n{body}\n{TOOL_CALL_END}" + + +def render_tool_response(content: str) -> str: + """Wrap tool response content in SmolLM3's tool_response tokens.""" + return f"{TOOL_RESP_START}\n{content}\n{TOOL_RESP_END}" + + +def convert_openai_messages(messages: list[dict], tools: list[dict] | None = None) -> list[dict]: + """Convert standard OpenAI-format messages to SmolLM3 native format. + + Transforms: + - assistant.tool_calls → content with startPos/endPos tokens + - tool role messages → user role with eni/eni_result tokens + - Adds system prompt with tool definitions if tools present + """ + converted = [] + + # Build system message with tool defs if present + if tools: + tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools) + system_content = ( + "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" + "### Tools\n\n" + "You may call one or more functions to assist with the user query.\n" + "You are provided with function signatures within XML tags:\n\n" + f"\n{tool_defs}\n\n\n" + 'For each function call, return a json object with function name and arguments within ' + f'{TOOL_CALL_START} {TOOL_CALL_END} tags:\n' + f'{TOOL_CALL_START}\n{{"name": , "arguments": }}\n{TOOL_CALL_END}' + ) + converted.append({"role": "system", "content": system_content}) + elif messages and messages[0].get("role") == "system": + converted.append({"role": "system", "content": messages[0]["content"]}) + messages = messages[1:] + else: + converted.append({ + "role": "system", + "content": "You are a helpful AI assistant named SmolLM, trained by Hugging Face." + }) + + for msg in messages: + role = msg.get("role", "user") + + if role == "user": + converted.append({"role": "user", "content": msg["content"]}) + + elif role == "assistant": + content = msg.get("content") or "" + tool_calls = msg.get("tool_calls") + if tool_calls: + tc_text = render_tool_calls(tool_calls) + full_content = f"{content}\n{tc_text}" if content else tc_text + converted.append({"role": "assistant", "content": full_content}) + else: + converted.append({"role": "assistant", "content": content}) + + elif role == "tool": + # Tool responses become user messages with eni/eni_result tokens + content = msg.get("content", "") + if isinstance(content, list): + content = " ".join(c.get("text", "") for c in content if isinstance(c, dict)) + converted.append({ + "role": "user", + "content": render_tool_response(str(content)) + }) + + return converted + + +def load_multiturn_dataset() -> list[dict]: + """Load interstellarninja/tool-calls-multiturn.""" + print("Loading interstellarninja/tool-calls-multiturn ...") + ds = load_dataset("interstellarninja/tool-calls-multiturn", split="train") + samples = [] + for row in ds: + messages = row.get("messages", []) + tools = row.get("tools") + if not messages or not any(m.get("tool_calls") for m in messages if m.get("role") == "assistant"): + continue # skip conversations with no tool calls + converted = convert_openai_messages(messages, tools) + samples.append({"messages": converted}) + print(f" → {len(samples)} samples with tool calls") + return samples + + +def load_hermes_fc_dataset() -> list[dict]: + """Load NousResearch/Hermes-Function-Calling-V1.""" + print("Loading NousResearch/Hermes-Function-Calling-V1 ...") + ds = load_dataset("NousResearch/Hermes-Function-Calling-V1", split="train") + samples = [] + for row in ds: + messages = row.get("messages", []) + tools = row.get("tools") + if not messages or not any(m.get("tool_calls") for m in messages if m.get("role") == "assistant"): + continue + converted = convert_openai_messages(messages, tools) + samples.append({"messages": converted}) + print(f" → {len(samples)} samples with tool calls") + return samples + + +def load_xlam_dataset() -> list[dict]: + """Load Salesforce/xLAM-function-calling-60k. + + This dataset uses a different format: each row has 'tools', 'instruction', and 'outputs'. + We convert to conversation format. + """ + print("Loading Salesforce/xLAM-function-calling-60k ...") + ds = load_dataset("Salesforce/xLAM-function-calling-60k", split="train") + samples = [] + for row in ds: + tools_raw = row.get("tools", "[]") + instruction = row.get("instruction", "") + outputs = row.get("answers", row.get("outputs", "")) + + if not instruction or not outputs: + continue + + try: + tools_list = json.loads(tools_raw) if isinstance(tools_raw, str) else tools_raw + except json.JSONDecodeError: + continue + + if not tools_list: + continue + + # Parse the model output — may contain one or more tool calls + try: + output_parsed = json.loads(outputs) if isinstance(outputs, str) else outputs + except json.JSONDecodeError: + continue + + # Build messages + messages = [{"role": "user", "content": instruction}] + + if isinstance(output_parsed, list): + # Multiple tool calls + tool_calls = [] + for item in output_parsed: + if isinstance(item, dict) and "name" in item: + tool_calls.append({ + "function": { + "name": item["name"], + "arguments": item.get("arguments", item.get("parameters", {})) + } + }) + if tool_calls: + messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""}) + elif isinstance(output_parsed, dict) and "name" in output_parsed: + messages.append({ + "role": "assistant", + "tool_calls": [{ + "function": { + "name": output_parsed["name"], + "arguments": output_parsed.get("arguments", output_parsed.get("parameters", {})) + } + }], + "content": "" + }) + else: + continue + + converted = convert_openai_messages(messages, tools_list) + samples.append({"messages": converted}) + + print(f" → {len(samples)} samples with tool calls") + return samples + + +def tokenize_sample(sample: dict, tokenizer) -> dict | None: + """Tokenize a sample using the model's chat template. + + Returns dict with input_ids, attention_mask, labels (with system/user masked to -100). + """ + messages = sample["messages"] + try: + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + enc = tokenizer(text, truncation=True, max_length=4096) + except Exception as e: + print(f" ⚠ Tokenization failed: {e}") + return None + + input_ids = enc["input_ids"] + attention_mask = enc["attention_mask"] + + # Build labels: mask system + user tokens, only train on assistant responses + labels = [-100] * len(input_ids) + + # Find assistant turn boundaries in the raw text + # We'll use a simpler approach: decode chunks and find assistant markers + ASSISTANT_START = "<|im_start|>assistant\n" + IM_END = "<|im_end|>" + + # Find all assistant spans in the tokenized text by decoding ranges + text_for_search = text + pos = 0 + while True: + start_idx = text_for_search.find(ASSISTANT_START, pos) + if start_idx == -1: + break + end_idx = text_for_search.find(IM_END, start_idx + len(ASSISTANT_START)) + if end_idx == -1: + end_idx = len(text_for_search) + + # Map character offsets to token offsets + # Approximate: count characters up to start/end, find token boundaries + char_to_start = start_idx + len(ASSISTANT_START) # skip the marker itself + char_to_end = end_idx + len(IM_END) + + # Use tokenizer offset mapping if available + enc_with_offsets = tokenizer(text, truncation=True, max_length=4096, return_offsets_mapping=True) + offsets = enc_with_offsets.get("offset_mapping", None) + + if offsets: + tok_start = None + tok_end = None + for ti, (cs, ce) in enumerate(offsets): + if cs >= char_to_start and tok_start is None: + tok_start = ti + if ce >= char_to_end: + tok_end = ti + 1 + break + if tok_start is not None and tok_end is not None: + for i in range(tok_start, min(tok_end, len(labels))): + labels[i] = input_ids[i] + + pos = end_idx + 1 + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=str, default="/data/processed") + parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)") + parser.add_argument("--tokenize", action="store_true", help="Also produce tokenized versions") + parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Load all datasets + all_samples = [] + all_samples.extend(load_multiturn_dataset()) + all_samples.extend(load_hermes_fc_dataset()) + all_samples.extend(load_xlam_dataset()) + + print(f"\nTotal raw samples: {len(all_samples)}") + + # Shuffle & split + random.seed(SEED) + random.shuffle(all_samples) + + if args.max_samples > 0: + all_samples = all_samples[:args.max_samples] + + val_count = max(1, int(len(all_samples) * VAL_FRACTION)) + val_samples = all_samples[:val_count] + train_samples = all_samples[val_count:] + + print(f"Train: {len(train_samples)}, Val: {len(val_samples)}") + + # Write raw JSONL + for split_name, split_data in [("train", train_samples), ("val", val_samples)]: + path = output_dir / f"{split_name}.jsonl" + with open(path, "w") as f: + for s in split_data: + f.write(json.dumps(s, ensure_ascii=False) + "\n") + print(f"Wrote {path}") + + # Optionally tokenize + if args.tokenize: + print(f"\nTokenizing with {args.model} ...") + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model) + + for split_name, split_data in [("train", train_samples), ("val", val_samples)]: + tok_path = output_dir / f"{split_name}_tokenized.jsonl" + count = 0 + with open(tok_path, "w") as f: + for s in split_data: + tok = tokenize_sample(s, tokenizer) + if tok: + f.write(json.dumps(tok) + "\n") + count += 1 + print(f"Wrote {tok_path} ({count} samples)") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c27d37b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +torch>=2.5.0 +transformers>=4.46.0 +peft>=0.13.0 +accelerate>=1.1.0 +datasets>=3.0.0 +bitsandbytes>=0.44.0 +scipy diff --git a/run.sh b/run.sh new file mode 100755 index 0000000..f9c3afa --- /dev/null +++ b/run.sh @@ -0,0 +1,55 @@ +#!/bin/bash +set -euo pipefail + +# ─── SmolLM3-3B LoRA Training Pipeline ─── +# Stages: prepare → train (skip with env vars) +# +# Usage: +# docker run --gpus all -v /path/to/output:/data smollora +# docker run --gpus all -e SKIP_PREP=1 -v /path/to/processed:/data/processed smollora + +MODEL="${MODEL:-HuggingFaceTB/SmolLM3-3B}" +DATA_DIR="${DATA_DIR:-/data/processed}" +OUTPUT_DIR="${OUTPUT_DIR:-/data/lora-output}" +EPOCHS="${EPOCHS:-3}" +BATCH_SIZE="${BATCH_SIZE:-4}" +LR="${LR:-2e-4}" +LORA_R="${LORA_R:-16}" +MAX_LENGTH="${MAX_LENGTH:-4096}" + +echo "🎭 SmolLM3-3B LoRA Training Pipeline" +echo " Model: $MODEL" +echo " Data: $DATA_DIR" +echo " Output: $OUTPUT_DIR" +echo " Epochs: $EPOCHS" +echo " Batch: $BATCH_SIZE" +echo " LR: $LR" +echo " LoRA r: $LORA_R" +echo "" + +# Stage 1: Data preparation +if [ "${SKIP_PREP:-0}" = "0" ]; then + echo "━━━ Stage 1: Data Preparation ━━━" + python /app/prepare_data.py \ + --output-dir "$DATA_DIR" \ + --model "$MODEL" + echo "✅ Data prepared in $DATA_DIR" +else + echo "⏭ Skipping data preparation (SKIP_PREP=1)" +fi + +# Stage 2: Training +echo "" +echo "━━━ Stage 2: LoRA Training ━━━" +python /app/train_lora.py \ + --data-dir "$DATA_DIR" \ + --model "$MODEL" \ + --output-dir "$OUTPUT_DIR" \ + --epochs "$EPOCHS" \ + --batch-size "$BATCH_SIZE" \ + --lr "$LR" \ + --lora-r "$LORA_R" \ + --max-length "$MAX_LENGTH" + +echo "" +echo "🎭 Training complete! Adapter saved to $OUTPUT_DIR/final" diff --git a/train_lora.py b/train_lora.py new file mode 100644 index 0000000..3a75307 --- /dev/null +++ b/train_lora.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +LoRA fine-tuning script for SmolLM3-3B tool-calling. + +Uses PEFT + transformers + accelerate. Designed to run inside the Docker container. + +Usage: + python train_lora.py \ + --data-dir /data/processed \ + --model HuggingFaceTB/SmolLM3-3B \ + --output-dir /data/lora-output \ + --epochs 3 \ + --batch-size 4 \ + --lr 2e-4 +""" + +import argparse +import json +from pathlib import Path + +import torch +from datasets import Dataset +from peft import LoraConfig, TaskType, get_peft_model +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + Trainer, + TrainingArguments, +) + + +def load_jsonl(path: Path) -> list[dict]: + samples = [] + with open(path) as f: + for line in f: + line = line.strip() + if line: + samples.append(json.loads(line)) + return samples + + +def tokenize_for_training(sample: dict, tokenizer, max_length: int = 4096) -> dict: + """Tokenize a chat-formatted sample and build labels. + + Masks everything except assistant responses (labels = -100 for non-assistant tokens). + """ + messages = sample["messages"] + + # Build the full text using the tokenizer's chat template + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + + enc = tokenizer( + text, + truncation=True, + max_length=max_length, + return_offsets_mapping=True, + ) + + input_ids = enc["input_ids"] + attention_mask = enc["attention_mask"] + labels = [-100] * len(input_ids) + + # Find assistant turn boundaries + ASSISTANT_MARKER = "<|im_start|>assistant\n" + END_MARKER = "<|im_end|>" + + offsets = enc.get("offset_mapping", []) + if not offsets: + # Fallback: just train on everything after first assistant turn + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + # Find assistant spans in the raw text + pos = 0 + while True: + start_idx = text.find(ASSISTANT_MARKER, pos) + if start_idx == -1: + break + + # Content starts after the marker + content_start = start_idx + len(ASSISTANT_MARKER) + end_idx = text.find(END_MARKER, content_start) + if end_idx == -1: + span_end = len(text) + else: + span_end = end_idx + len(END_MARKER) + + # Map character offsets to token indices + tok_start = None + tok_end = None + for ti, (cs, ce) in enumerate(offsets): + if cs >= content_start and tok_start is None: + tok_start = ti + if ce >= span_end: + tok_end = ti + 1 + break + + if tok_start is not None and tok_end is not None: + for i in range(tok_start, min(tok_end, len(labels))): + labels[i] = input_ids[i] + + pos = (end_idx if end_idx != -1 else span_end) + 1 + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +def main(): + parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling") + parser.add_argument("--data-dir", type=str, default="/data/processed") + parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B") + parser.add_argument("--output-dir", type=str, default="/data/lora-output") + parser.add_argument("--epochs", type=int, default=3) + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--grad-accum", type=int, default=4) + parser.add_argument("--lr", type=float, default=2e-4) + parser.add_argument("--warmup-ratio", type=float, default=0.03) + parser.add_argument("--max-length", type=int, default=4096) + parser.add_argument("--lora-r", type=int, default=16) + parser.add_argument("--lora-alpha", type=int, default=32) + parser.add_argument("--lora-dropout", type=float, default=0.05) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--fp16", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=True) + parser.add_argument("--resume-from", type=str, default=None) + args = parser.parse_args() + + # Load tokenizer + print(f"Loading tokenizer: {args.model}") + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Load model + print(f"Loading model: {args.model}") + model = AutoModelForCausalLM.from_pretrained( + args.model, + trust_remote_code=True, + torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32), + device_map="auto", + ) + + # Configure LoRA + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + bias="none", + ) + + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + + # Load data + data_dir = Path(args.data_dir) + train_data = load_jsonl(data_dir / "train.jsonl") + val_data = load_jsonl(data_dir / "val.jsonl") + print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}") + + # Tokenize + print("Tokenizing training data ...") + train_dataset = Dataset.from_list(train_data).map( + lambda x: tokenize_for_training(x, tokenizer, args.max_length), + remove_columns=["messages"], + desc="Tokenizing train", + ) + val_dataset = Dataset.from_list(val_data).map( + lambda x: tokenize_for_training(x, tokenizer, args.max_length), + remove_columns=["messages"], + desc="Tokenizing val", + ) + + # Data collator + data_collator = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + padding=True, + return_tensors="pt", + ) + + # Training arguments + training_args = TrainingArguments( + output_dir=args.output_dir, + num_train_epochs=args.epochs, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, + gradient_accumulation_steps=args.grad_accum, + learning_rate=args.lr, + warmup_ratio=args.warmup_ratio, + lr_scheduler_type="cosine", + logging_steps=10, + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=3, + load_best_model_at_end=True, + metric_for_best_model="eval_loss", + greater_is_better=False, + bf16=args.bf16, + fp16=args.fp16, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + optim="adamw_torch_fused", + seed=args.seed, + report_to="none", + dataloader_num_workers=4, + dataloader_pin_memory=True, + ) + + # Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=val_dataset, + data_collator=data_collator, + ) + + # Train + print("Starting training ...") + trainer.train(resume_from_checkpoint=args.resume_from) + + # Save final adapter + print(f"Saving LoRA adapter to {args.output_dir}/final") + model.save_pretrained(f"{args.output_dir}/final") + tokenizer.save_pretrained(f"{args.output_dir}/final") + + # Also save the tokenizer chat template for deployment + print("Done! 🎭") + + +if __name__ == "__main__": + main()