Initial LoRA training setup for SmolLM3-3B tool calling
This commit is contained in:
22
Dockerfile
Normal file
22
Dockerfile
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel

# System deps
# NOTE(review): `packaging` is a Python distribution rather than an apt package
# (no apt package with that exact name appears to exist), so it is installed
# with pip in the next section instead of being listed here.
# The apt lists are removed in the same layer so they never reach the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
        curl \
        git \
        ninja-build \
        wget \
    && rm -rf /var/lib/apt/lists/*

# Python deps
# requirements.txt is copied on its own so this layer is only rebuilt when the
# dependency list changes, not when the scripts below change.
COPY requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir packaging -r /tmp/requirements.txt

# Copy scripts
WORKDIR /app
COPY prepare_data.py train_lora.py run.sh /app/

# Data and output dirs
RUN mkdir -p /data/processed /data/lora-output /data/models

# Default: run the full pipeline
ENTRYPOINT ["/app/run.sh"]
|
||||||
81
README.md
Normal file
81
README.md
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# SmolLM3-3B LoRA — Tool Calling Fine-Tune
|
||||||
|
|
||||||
|
LoRA adapter training to make SmolLM3-3B a tool-calling savant.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build
|
||||||
|
docker build -t smollora .
|
||||||
|
|
||||||
|
# Run full pipeline (prepare data + train)
|
||||||
|
docker run --gpus all \
|
||||||
|
-v /path/on/host/output:/data/lora-output \
|
||||||
|
smollora
|
||||||
|
|
||||||
|
# Skip data prep if you already have processed data
|
||||||
|
docker run --gpus all \
|
||||||
|
-e SKIP_PREP=1 \
|
||||||
|
-v /path/on/host/processed:/data/processed \
|
||||||
|
-v /path/on/host/output:/data/lora-output \
|
||||||
|
smollora
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
| Var | Default | Description |
|
||||||
|
|-----|---------|-------------|
|
||||||
|
| `MODEL` | `HuggingFaceTB/SmolLM3-3B` | Base model (HF repo or local path) |
|
||||||
|
| `DATA_DIR` | `/data/processed` | Processed data directory |
|
||||||
|
| `OUTPUT_DIR` | `/data/lora-output` | Training output directory |
|
||||||
|
| `EPOCHS` | `3` | Training epochs |
|
||||||
|
| `BATCH_SIZE` | `4` | Per-device batch size |
|
||||||
|
| `LR` | `2e-4` | Learning rate |
|
||||||
|
| `LORA_R` | `16` | LoRA rank |
|
||||||
|
| `MAX_LENGTH` | `4096` | Max sequence length |
|
||||||
|
| `SKIP_PREP` | `0` | Set to `1` to skip data preparation |
|
||||||
|
|
||||||
|
## Datasets
|
||||||
|
|
||||||
|
Three datasets combined and converted to SmolLM3's native token format:
|
||||||
|
|
||||||
|
1. **interstellarninja/tool-calls-multiturn** — Multi-turn tool calling conversations
|
||||||
|
2. **NousResearch/Hermes-Function-Calling-V1** — Hermes-format function calling
|
||||||
|
3. **Salesforce/xLAM-function-calling-60k** — Large-scale function calling (60k samples)
|
||||||
|
|
||||||
|
Only conversations containing tool calls are kept. All are normalized to SmolLM3's special tokens:
|
||||||
|
- Tool calls → `<tool_call>`/`</tool_call>` (token IDs 128015/128016)
|
||||||
|
- Tool responses → `<tool_response>`/`</tool_response>` (token IDs 128013/128014)
|
||||||
|
|
||||||
|
## LoRA Configuration
|
||||||
|
|
||||||
|
- **Rank:** 16
|
||||||
|
- **Alpha:** 32
|
||||||
|
- **Target modules:** q/k/v/o projections + gate/up/down MLP
|
||||||
|
- **Dropout:** 0.05
|
||||||
|
- **Scheduler:** Cosine with 3% warmup
|
||||||
|
- **Optimizer:** AdamW (fused)
|
||||||
|
- **Gradient checkpointing:** Enabled
|
||||||
|
|
||||||
|
## Output
|
||||||
|
|
||||||
|
The trained adapter is saved to `$OUTPUT_DIR/final/`. To use with vLLM:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Merge adapter into base model (recommended for vLLM)
|
||||||
|
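python -c "from transformers import AutoModelForCausalLM; from peft import PeftModel; base = AutoModelForCausalLM.from_pretrained('HuggingFaceTB/SmolLM3-3B'); PeftModel.from_pretrained(base, '$OUTPUT_DIR/final').merge_and_unload().save_pretrained('$OUTPUT_DIR/merged')"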
|
||||||
|
# Or pass the adapter path directly with --enable-lora
|
||||||
|
```
|
||||||
|
|
||||||
|
## SSH Deployment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# On GPU box, after SSH-ing in:
|
||||||
|
docker run --gpus all -v ~/smol-data:/data smollora
|
||||||
|
|
||||||
|
# Or with local model cache:
|
||||||
|
docker run --gpus all \
|
||||||
|
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
-v ~/smol-data:/data \
|
||||||
|
smollora
|
||||||
|
```
|
||||||
102
model-files/chat_template.jinja
Normal file
102
model-files/chat_template.jinja
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
{# ───── defaults ───── #}
{%- if enable_thinking is not defined -%}
  {%- set enable_thinking = true -%}
{%- endif -%}

{# ───── reasoning mode ───── #}
{%- if enable_thinking -%}
  {%- set reasoning_mode = "/think" -%}
{%- else -%}
  {%- set reasoning_mode = "/no_think" -%}
{%- endif -%}

{# ───── header (system message) ───── #}
{{- "<|im_start|>system\n" -}}

{# Start from empty values so the checks further down are well-defined even when the conversation has no leading system message. #}
{%- set system_message = "" -%}
{%- set custom_instructions = "" -%}
{%- if messages[0].role == "system" -%}
  {%- set system_message = messages[0].content -%}
  {# A /no_think or /think flag inside the system message overrides enable_thinking. #}
  {%- if "/no_think" in system_message -%}
    {%- set reasoning_mode = "/no_think" -%}
  {%- elif "/think" in system_message -%}
    {%- set reasoning_mode = "/think" -%}
  {%- endif -%}
  {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
{%- endif -%}

{%- if "/system_override" in system_message -%}
  {# /system_override: emit the caller's system text only, without the metadata/tools scaffolding. #}
  {{- custom_instructions.replace("/system_override", "").rstrip() -}}
{%- else -%}
  {{- "## Metadata\n\n" -}}
  {{- "Knowledge Cutoff Date: June 2025\n" -}}
  {%- set today = strftime_now("%d %B %Y") -%}
  {{- "Today Date: " ~ today ~ "\n" -}}
  {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}

  {{- "## Custom Instructions\n\n" -}}
  {%- if custom_instructions -%}
    {{- custom_instructions + "\n\n" -}}
  {%- elif reasoning_mode == "/think" -%}
    {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
  {%- else -%}
    {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
  {%- endif -%}

  {%- if xml_tools or python_tools or tools -%}
    {{- "### Tools\n\n" -}}
    {%- if xml_tools or tools -%}
      {# `tools` takes precedence over `xml_tools` when both are supplied. #}
      {%- if tools -%}
        {%- set xml_tools = tools -%}
      {%- endif -%}
      {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
      {%- for tool in xml_tools[:] -%}
        {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | tojson) ~ "\n" -%}
      {%- endfor -%}
      {%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
      {{- xml_tool_string -}}
    {%- endif -%}
    {%- if python_tools -%}
      {%- set ns = namespace(python_tool_string="You may call one or more functions as python tools.\n<tools>\n") -%}
      {%- for tool in python_tools[:] -%}
        {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
      {%- endfor -%}
      {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions." -%}
      {{- python_tool_string -}}
    {%- endif -%}
    {{- "\n\n" -}}
  {%- endif -%}
{%- endif -%}
{{- "<|im_end|>\n" -}}

{# ───── main loop ───── #}
{# System messages are not re-emitted here; the first one was folded into the header above. #}
{%- for message in messages -%}
  {%- if message.role == "user" -%}
    {{ "<|im_start|>user\n" + message.content + "<|im_end|>\n" }}
  {%- elif message.role == "assistant" -%}
    {% generation %}
    {%- if message.tool_calls -%}
      {%- set ns = namespace(tc_text="") -%}
      {%- for tc in message.tool_calls -%}
        {# Arguments may arrive as a JSON string or as a mapping; serialize mappings so the rendered call stays valid JSON. #}
        {%- set tc_args = tc.function.arguments if tc.function.arguments is string else (tc.function.arguments | tojson) -%}
        {%- set ns.tc_text = ns.tc_text ~ "<tool_call>\n{\"name\": \"" ~ tc.function.name ~ "\", \"arguments\": " ~ tc_args ~ "}\n</tool_call>" -%}
      {%- endfor -%}
      {{ "<|im_start|>assistant\n" ~ (message.content if message.content is string else "") ~ ns.tc_text ~ "<|im_end|>\n" }}
    {%- else -%}
      {%- if reasoning_mode == "/think" -%}
        {{ "<|im_start|>assistant\n<think>\n" ~ (message.content if message.content is string else "") ~ "\n</think><|im_end|>\n" }}
      {%- else -%}
        {{ "<|im_start|>assistant\n" ~ (message.content if message.content is string else "") ~ "<|im_end|>\n" }}
      {%- endif -%}
    {%- endif -%}
    {% endgeneration %}
  {%- elif message.role == "tool" -%}
    {# Tool results are presented to the model as a user turn wrapped in <tool_response> tags. #}
    {{ "<|im_start|>user\n<tool_response>\n" ~ (message.content if message.content is string else "") ~ "\n</tool_response><|im_end|>\n" }}
  {%- endif -%}
{%- endfor -%}

{# ───── generation prompt ───── #}
{%- if add_generation_prompt -%}
  {%- if reasoning_mode == "/think" -%}
    {{ "<|im_start|>assistant\n<think>\n" }}
  {%- else -%}
    {{ "<|im_start|>assistant\n" }}
  {%- endif -%}
{%- endif -%}
|
||||||
130
model-files/gen_template.py
Normal file
130
model-files/gen_template.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
#!/usr/bin/env python3
"""Generate the PRODUCTION fixed chat_template.jinja for SmolLM3-3B.

v2: Fixed thinking mode direction - /think now opens the <think> tag
in the generation prompt so the model actually generates reasoning.

Usage: gen_template.py [OUTPUT_PATH]   (default: /root/chat_template.jinja)
"""
import sys

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

# The tag strings are decoded from their token IDs rather than typed by hand,
# so the template contains exactly what the tokenizer maps to those IDs.
THINK_S = tok.decode([128002])
THINK_E = tok.decode([128003])
RESP_S = tok.decode([128013])
RESP_E = tok.decode([128014])
TC_S = tok.decode([128015])
TC_E = tok.decode([128016])

# Template fragments, joined in order at the end. Raw-string fragments are
# copied into the template verbatim; the others splice in the decoded tags.
T = []

# ─── defaults & system header ───
T.append(r"""{# ───── defaults ───── #}
{%- if enable_thinking is not defined -%}
  {%- set enable_thinking = true -%}
{%- endif -%}

{# ───── reasoning mode ───── #}
{%- if enable_thinking -%}
  {%- set reasoning_mode = "/think" -%}
{%- else -%}
  {%- set reasoning_mode = "/no_think" -%}
{%- endif -%}

{# ───── header (system message) ───── #}
{{- "<|im_start|>system\n" -}}

{# Start from empty values so the checks further down are well-defined even when the conversation has no leading system message. #}
{%- set system_message = "" -%}
{%- set custom_instructions = "" -%}
{%- if messages[0].role == "system" -%}
  {%- set system_message = messages[0].content -%}
  {%- if "/no_think" in system_message -%}
    {%- set reasoning_mode = "/no_think" -%}
  {%- elif "/think" in system_message -%}
    {%- set reasoning_mode = "/think" -%}
  {%- endif -%}
  {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
{%- endif -%}

{%- if "/system_override" in system_message -%}
  {{- custom_instructions.replace("/system_override", "").rstrip() -}}
{%- else -%}
  {{- "## Metadata\n\n" -}}
  {{- "Knowledge Cutoff Date: June 2025\n" -}}
  {%- set today = strftime_now("%d %B %Y") -%}
  {{- "Today Date: " ~ today ~ "\n" -}}
  {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}

  {{- "## Custom Instructions\n\n" -}}
  {%- if custom_instructions -%}
    {{- custom_instructions + "\n\n" -}}
  {%- elif reasoning_mode == "/think" -%}
    {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
  {%- else -%}
    {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
  {%- endif -%}

  {%- if xml_tools or python_tools or tools -%}
    {{- "### Tools\n\n" -}}
    {%- if xml_tools or tools -%}
      {%- if tools -%}
        {%- set xml_tools = tools -%}
      {%- endif -%}
      {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
      {%- for tool in xml_tools[:] -%}
        {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | tojson) ~ "\n" -%}
      {%- endfor -%}""")

# Tool calling format with special tokens
T.append('\n      {%- set xml_tool_string = ns.xml_tool_string + "</tools>\\n\\nFor each function call, return a json object with function name and arguments within ' + TC_S + ' XML tags:\\n' + TC_S + '\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n' + TC_E + '" -%}\n')

T.append(r"""      {{- xml_tool_string -}}
    {%- endif -%}
    {%- if python_tools -%}
      {%- set ns = namespace(python_tool_string="You may call one or more functions as python tools.\n<tools>\n") -%}
      {%- for tool in python_tools[:] -%}
        {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
      {%- endfor -%}
      {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions." -%}
      {{- python_tool_string -}}
    {%- endif -%}
    {{- "\n\n" -}}
  {%- endif -%}
{%- endif -%}
{{- "<|im_end|>\n" -}}""")

# ─── Main loop ───
T.append(r"""

{# ───── main loop ───── #}
{%- for message in messages -%}
  {%- if message.role == "user" -%}
    {{ "<|im_start|>user\n" + message.content + "<|im_end|>\n" }}
  {%- elif message.role == "assistant" -%}
    {% generation %}
    {%- if message.tool_calls -%}""")

# FIX: Render tool calls with TC_S/TC_E tokens.
# Arguments may be a JSON string or a mapping; mappings go through `tojson`
# so the rendered call is valid JSON instead of a Python repr.
T.append(
    '\n      {%- set ns = namespace(tc_text="") -%}\n'
    '      {%- for tc in message.tool_calls -%}\n'
    '        {%- set tc_args = tc.function.arguments if tc.function.arguments is string else (tc.function.arguments | tojson) -%}\n'
    '        {%- set ns.tc_text = ns.tc_text ~ "' + TC_S + '\\n{\\"name\\": \\"" ~ tc.function.name ~ "\\", \\"arguments\\": " ~ tc_args ~ "}\\n' + TC_E + '" -%}\n'
    '      {%- endfor -%}\n'
    '      {{ "<|im_start|>assistant\\n" ~ (message.content if message.content is string else "") ~ ns.tc_text ~ "<|im_end|>\\n" }}\n'
)

T.append(r"""    {%- else -%}""")

# FIX v2: /think = think tags, /no_think = plain text (CORRECT direction now)
T.append(
    '\n      {%- if reasoning_mode == "/think" -%}\n'
    '        {{ "<|im_start|>assistant\\n' + THINK_S + '\\n" ~ (message.content if message.content is string else "") ~ "\\n' + THINK_E + '<|im_end|>\\n" }}\n'
    '      {%- else -%}\n'
    '        {{ "<|im_start|>assistant\\n" ~ (message.content if message.content is string else "") ~ "<|im_end|>\\n" }}\n'
    '      {%- endif -%}\n'
)

T.append(r"""    {%- endif -%}
    {% endgeneration %}""")

# FIX: Tool role with RESP_S/RESP_E tokens
T.append(
    '\n  {%- elif message.role == "tool" -%}\n'
    '    {{ "<|im_start|>user\\n' + RESP_S + '\\n" ~ (message.content if message.content is string else "") ~ "\\n' + RESP_E + '<|im_end|>\\n" }}\n'
)

T.append(r"""  {%- endif -%}
{%- endfor -%}""")

# ─── Generation prompt ───
# FIX v2: /think opens the <think> tag so the model generates reasoning, /no_think is bare
T.append(
    '\n\n{# ───── generation prompt ───── #}\n'
    '{%- if add_generation_prompt -%}\n'
    '  {%- if reasoning_mode == "/think" -%}\n'
    '    {{ "<|im_start|>assistant\\n' + THINK_S + '\\n" }}\n'
    '  {%- else -%}\n'
    '    {{ "<|im_start|>assistant\\n" }}\n'
    '  {%- endif -%}\n'
    '{%- endif -%}\n'
)

template = ''.join(T)

# An optional first CLI argument overrides where the template is written.
out_path = sys.argv[1] if len(sys.argv) > 1 else '/root/chat_template.jinja'

with open(out_path, 'w', encoding='utf-8') as f:
    f.write(template)

print(f"Production template v2 written to {out_path}")
# The template contains non-ASCII characters, so the byte size is measured on
# the UTF-8 encoding rather than on the character count.
print(f"Length: {len(template.encode('utf-8'))} bytes")
|
||||||
349
prepare_data.py
Normal file
349
prepare_data.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning.
|
||||||
|
|
||||||
|
Combines three datasets:
|
||||||
|
1. interstellarninja/tool-calls-multiturn
|
||||||
|
2. NousResearch/Hermes-Function-Calling-V1
|
||||||
|
3. Salesforce/xLAM-function-calling-60k
|
||||||
|
|
||||||
|
Converts all to SmolLM3's native chat format with proper special tokens:
|
||||||
|
- Tool calls wrapped in <tool_call>/</tool_call> tokens (IDs 128015/128016)
- Tool responses wrapped in <tool_response>/</tool_response> tokens (IDs 128013/128014)
- Thinking wrapped in <think>/</think> tags
|
||||||
|
|
||||||
|
Output: train.jsonl, val.jsonl (tokenized & raw)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
# SmolLM3 special tokens. These must be the exact strings that the fixed
# chat_template.jinja emits (gen_template.py decodes them from the token IDs
# noted below); otherwise the training text would not line up with what the
# template renders at inference time.
TOOL_CALL_START = "<tool_call>"  # token 128015
TOOL_CALL_END = "</tool_call>"  # token 128016
TOOL_RESP_START = "<tool_response>"  # token 128013
TOOL_RESP_END = "</tool_response>"  # token 128014
THINK_START = "<think>"  # token 128002
THINK_END = "</think>"  # token 128003

VAL_FRACTION = 0.05  # share of the shuffled samples held out as the validation split
SEED = 42  # seed for the shuffle that precedes the train/val split
|
||||||
|
|
||||||
|
|
||||||
|
def render_tool_calls(tool_calls: list[dict]) -> str:
    """Render a list of tool calls as one tool-call block.

    Args:
        tool_calls: Entries shaped like
            ``{"function": {"name": ..., "arguments": ...}}``. ``arguments``
            may be an already-serialized JSON string (used verbatim) or any
            JSON-serializable object.

    Returns:
        One ``{"name": ..., "arguments": ...}`` line per call, placed between
        ``TOOL_CALL_START`` and ``TOOL_CALL_END``, each on its own line.

    Raises:
        KeyError: If an entry lacks ``function``, ``name`` or ``arguments``.
    """
    rendered = []
    for call in tool_calls:
        fn = call["function"]
        raw_args = fn["arguments"]
        if isinstance(raw_args, str):
            args_json = raw_args
        else:
            args_json = json.dumps(raw_args, ensure_ascii=False)
        # Serialize the name through json.dumps so quotes or backslashes in it
        # cannot break the surrounding JSON object.
        name_json = json.dumps(str(fn["name"]), ensure_ascii=False)
        rendered.append(f'{{"name": {name_json}, "arguments": {args_json}}}')
    # NOTE(review): all calls share a single start/end pair here, whereas the
    # chat template wraps each call in its own pair - confirm which layout the
    # training data should use.
    body = "\n".join(rendered)
    return f"{TOOL_CALL_START}\n{body}\n{TOOL_CALL_END}"
|
||||||
|
|
||||||
|
|
||||||
|
def render_tool_response(content: str) -> str:
    """Return ``content`` on its own line between the tool-response markers.

    The result is ``TOOL_RESP_START``, the content and ``TOOL_RESP_END``,
    joined by single newlines.
    """
    pieces = (TOOL_RESP_START, f"{content}", TOOL_RESP_END)
    return "\n".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_openai_messages(messages: list[dict], tools: list[dict] | None = None) -> list[dict]:
|
||||||
|
"""Convert standard OpenAI-format messages to SmolLM3 native format.
|
||||||
|
|
||||||
|
Transforms:
|
||||||
|
- assistant.tool_calls → content with startPos/endPos tokens
|
||||||
|
- tool role messages → user role with eni/eni_result tokens
|
||||||
|
- Adds system prompt with tool definitions if tools present
|
||||||
|
"""
|
||||||
|
converted = []
|
||||||
|
|
||||||
|
# Build system message with tool defs if present
|
||||||
|
if tools:
|
||||||
|
tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools)
|
||||||
|
system_content = (
|
||||||
|
"You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n"
|
||||||
|
"### Tools\n\n"
|
||||||
|
"You may call one or more functions to assist with the user query.\n"
|
||||||
|
"You are provided with function signatures within <tools></tools> XML tags:\n\n"
|
||||||
|
f"<tools>\n{tool_defs}\n</tools>\n\n"
|
||||||
|
'For each function call, return a json object with function name and arguments within '
|
||||||
|
f'{TOOL_CALL_START} {TOOL_CALL_END} tags:\n'
|
||||||
|
f'{TOOL_CALL_START}\n{{"name": <function-name>, "arguments": <args-json-object>}}\n{TOOL_CALL_END}'
|
||||||
|
)
|
||||||
|
converted.append({"role": "system", "content": system_content})
|
||||||
|
elif messages and messages[0].get("role") == "system":
|
||||||
|
converted.append({"role": "system", "content": messages[0]["content"]})
|
||||||
|
messages = messages[1:]
|
||||||
|
else:
|
||||||
|
converted.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful AI assistant named SmolLM, trained by Hugging Face."
|
||||||
|
})
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "user")
|
||||||
|
|
||||||
|
if role == "user":
|
||||||
|
converted.append({"role": "user", "content": msg["content"]})
|
||||||
|
|
||||||
|
elif role == "assistant":
|
||||||
|
content = msg.get("content") or ""
|
||||||
|
tool_calls = msg.get("tool_calls")
|
||||||
|
if tool_calls:
|
||||||
|
tc_text = render_tool_calls(tool_calls)
|
||||||
|
full_content = f"{content}\n{tc_text}" if content else tc_text
|
||||||
|
converted.append({"role": "assistant", "content": full_content})
|
||||||
|
else:
|
||||||
|
converted.append({"role": "assistant", "content": content})
|
||||||
|
|
||||||
|
elif role == "tool":
|
||||||
|
# Tool responses become user messages with eni/eni_result tokens
|
||||||
|
content = msg.get("content", "")
|
||||||
|
if isinstance(content, list):
|
||||||
|
content = " ".join(c.get("text", "") for c in content if isinstance(c, dict))
|
||||||
|
converted.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": render_tool_response(str(content))
|
||||||
|
})
|
||||||
|
|
||||||
|
return converted
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tool_call_conversations(repo_id: str) -> list[dict]:
    """Load the train split of ``repo_id`` and keep conversations with tool calls.

    A row is kept only when it has messages and at least one assistant message
    with a non-empty ``tool_calls`` entry. Kept rows are converted with
    ``convert_openai_messages`` and returned as ``{"messages": [...]}`` dicts.

    NOTE(review): assumes each row exposes OpenAI-style ``messages`` (and an
    optional ``tools``) column - confirm against the dataset's actual schema;
    rows without ``messages`` are skipped.
    """
    print(f"Loading {repo_id} ...")
    ds = load_dataset(repo_id, split="train")
    samples = []
    for row in ds:
        messages = row.get("messages", [])
        if not messages:
            continue
        has_tool_call = any(
            m.get("tool_calls") for m in messages if m.get("role") == "assistant"
        )
        if not has_tool_call:
            continue  # skip conversations with no tool calls
        converted = convert_openai_messages(messages, row.get("tools"))
        samples.append({"messages": converted})
    print(f" → {len(samples)} samples with tool calls")
    return samples


def load_multiturn_dataset() -> list[dict]:
    """Load interstellarninja/tool-calls-multiturn."""
    return _load_tool_call_conversations("interstellarninja/tool-calls-multiturn")


def load_hermes_fc_dataset() -> list[dict]:
    """Load NousResearch/Hermes-Function-Calling-V1."""
    return _load_tool_call_conversations("NousResearch/Hermes-Function-Calling-V1")
|
||||||
|
|
||||||
|
|
||||||
|
def _xlam_tool_calls(parsed) -> list[dict]:
    """Turn a decoded xLAM answer into OpenAI-style ``tool_calls`` entries.

    ``parsed`` may be a single call object or a list of them; anything else,
    and any item that is not a dict with a ``name``, contributes nothing.
    """
    if isinstance(parsed, dict):
        items = [parsed]
    elif isinstance(parsed, list):
        items = parsed
    else:
        items = []
    return [
        {
            "function": {
                "name": item["name"],
                "arguments": item.get("arguments", item.get("parameters", {})),
            }
        }
        for item in items
        if isinstance(item, dict) and "name" in item
    ]


def load_xlam_dataset() -> list[dict]:
    """Load Salesforce/xLAM-function-calling-60k.

    This dataset is not in conversation format: each row carries the user
    request, the available tools and the expected call(s). Every usable row
    becomes a two-message conversation (user request, assistant tool calls)
    that is passed through ``convert_openai_messages``.

    NOTE(review): the public dataset card lists the columns as ``query``,
    ``tools`` and ``answers``; ``instruction`` and ``outputs`` are still read
    as alternatives - confirm against the dataset actually downloaded.

    Returns:
        ``{"messages": [...]}`` dicts. Rows with no request, no tools, invalid
        JSON, or no recognizable tool call are skipped.
    """
    print("Loading Salesforce/xLAM-function-calling-60k ...")
    ds = load_dataset("Salesforce/xLAM-function-calling-60k", split="train")
    samples = []
    for row in ds:
        tools_raw = row.get("tools", "[]")
        instruction = row.get("instruction") or row.get("query", "")
        outputs = row.get("answers", row.get("outputs", ""))

        if not instruction or not outputs:
            continue

        try:
            tools_list = json.loads(tools_raw) if isinstance(tools_raw, str) else tools_raw
        except json.JSONDecodeError:
            continue

        if not tools_list:
            continue

        # Parse the model output — may contain one or more tool calls
        try:
            output_parsed = json.loads(outputs) if isinstance(outputs, str) else outputs
        except json.JSONDecodeError:
            continue

        tool_calls = _xlam_tool_calls(output_parsed)
        if not tool_calls:
            # Without a tool call there would be no assistant turn to learn from.
            continue

        # Build messages
        messages = [
            {"role": "user", "content": instruction},
            {"role": "assistant", "tool_calls": tool_calls, "content": ""},
        ]
        converted = convert_openai_messages(messages, tools_list)
        samples.append({"messages": converted})

    print(f" → {len(samples)} samples with tool calls")
    return samples
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_sample(sample: dict, tokenizer) -> dict | None:
|
||||||
|
"""Tokenize a sample using the model's chat template.
|
||||||
|
|
||||||
|
Returns dict with input_ids, attention_mask, labels (with system/user masked to -100).
|
||||||
|
"""
|
||||||
|
messages = sample["messages"]
|
||||||
|
try:
|
||||||
|
text = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
)
|
||||||
|
enc = tokenizer(text, truncation=True, max_length=4096)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠ Tokenization failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
input_ids = enc["input_ids"]
|
||||||
|
attention_mask = enc["attention_mask"]
|
||||||
|
|
||||||
|
# Build labels: mask system + user tokens, only train on assistant responses
|
||||||
|
labels = [-100] * len(input_ids)
|
||||||
|
|
||||||
|
# Find assistant turn boundaries in the raw text
|
||||||
|
# We'll use a simpler approach: decode chunks and find assistant markers
|
||||||
|
ASSISTANT_START = "<|im_start|>assistant\n"
|
||||||
|
IM_END = "<|im_end|>"
|
||||||
|
|
||||||
|
# Find all assistant spans in the tokenized text by decoding ranges
|
||||||
|
text_for_search = text
|
||||||
|
pos = 0
|
||||||
|
while True:
|
||||||
|
start_idx = text_for_search.find(ASSISTANT_START, pos)
|
||||||
|
if start_idx == -1:
|
||||||
|
break
|
||||||
|
end_idx = text_for_search.find(IM_END, start_idx + len(ASSISTANT_START))
|
||||||
|
if end_idx == -1:
|
||||||
|
end_idx = len(text_for_search)
|
||||||
|
|
||||||
|
# Map character offsets to token offsets
|
||||||
|
# Approximate: count characters up to start/end, find token boundaries
|
||||||
|
char_to_start = start_idx + len(ASSISTANT_START) # skip the marker itself
|
||||||
|
char_to_end = end_idx + len(IM_END)
|
||||||
|
|
||||||
|
# Use tokenizer offset mapping if available
|
||||||
|
enc_with_offsets = tokenizer(text, truncation=True, max_length=4096, return_offsets_mapping=True)
|
||||||
|
offsets = enc_with_offsets.get("offset_mapping", None)
|
||||||
|
|
||||||
|
if offsets:
|
||||||
|
tok_start = None
|
||||||
|
tok_end = None
|
||||||
|
for ti, (cs, ce) in enumerate(offsets):
|
||||||
|
if cs >= char_to_start and tok_start is None:
|
||||||
|
tok_start = ti
|
||||||
|
if ce >= char_to_end:
|
||||||
|
tok_end = ti + 1
|
||||||
|
break
|
||||||
|
if tok_start is not None and tok_end is not None:
|
||||||
|
for i in range(tok_start, min(tok_end, len(labels))):
|
||||||
|
labels[i] = input_ids[i]
|
||||||
|
|
||||||
|
pos = end_idx + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: load, shuffle, split and write the training data.

    Loads the three datasets, shuffles them with ``SEED``, optionally caps the
    total with ``--max-samples``, holds out ``VAL_FRACTION`` (at least one
    sample) for validation and writes ``train.jsonl`` / ``val.jsonl`` to
    ``--output-dir``. With ``--tokenize`` it also writes
    ``train_tokenized.jsonl`` / ``val_tokenized.jsonl`` using the tokenizer of
    ``--model``.

    Raises:
        SystemExit: If no samples were loaded, so the pipeline stops here
            instead of writing empty splits.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str, default="/data/processed")
    parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)")
    parser.add_argument("--tokenize", action="store_true", help="Also produce tokenized versions")
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load all datasets
    all_samples = []
    all_samples.extend(load_multiturn_dataset())
    all_samples.extend(load_hermes_fc_dataset())
    all_samples.extend(load_xlam_dataset())

    print(f"\nTotal raw samples: {len(all_samples)}")
    if not all_samples:
        raise SystemExit("No samples were loaded from any dataset; nothing to write.")

    # Shuffle & split
    random.seed(SEED)
    random.shuffle(all_samples)

    if args.max_samples > 0:
        all_samples = all_samples[:args.max_samples]

    # The validation split is taken from the front of the shuffled list.
    val_count = max(1, int(len(all_samples) * VAL_FRACTION))
    val_samples = all_samples[:val_count]
    train_samples = all_samples[val_count:]

    print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")

    # Write raw JSONL
    for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
        path = output_dir / f"{split_name}.jsonl"
        # ensure_ascii=False emits non-ASCII text as-is, so the file encoding
        # is pinned to UTF-8 instead of depending on the locale.
        with open(path, "w", encoding="utf-8") as f:
            for s in split_data:
                f.write(json.dumps(s, ensure_ascii=False) + "\n")
        print(f"Wrote {path}")

    # Optionally tokenize
    if args.tokenize:
        print(f"\nTokenizing with {args.model} ...")
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model)

        for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
            tok_path = output_dir / f"{split_name}_tokenized.jsonl"
            count = 0
            with open(tok_path, "w", encoding="utf-8") as f:
                for s in split_data:
                    tok = tokenize_sample(s, tokenizer)
                    if tok:  # samples that failed to tokenize are skipped
                        f.write(json.dumps(tok) + "\n")
                        count += 1
            print(f"Wrote {tok_path} ({count} samples)")


if __name__ == "__main__":
    main()
|
||||||
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Python dependencies installed into the training image by the Dockerfile.
torch>=2.5.0
# SmolLM3 model support first shipped in transformers 4.53.0.
transformers>=4.53.0
peft>=0.13.0
accelerate>=1.1.0
datasets>=3.0.0
bitsandbytes>=0.44.0
scipy
|
||||||
55
run.sh
Executable file
55
run.sh
Executable file
@@ -0,0 +1,55 @@
|
|||||||
|
#!/bin/bash
set -euo pipefail

# ─── SmolLM3-3B LoRA Training Pipeline ───
# Stages: prepare → train (skip with env vars)
#
# Usage:
# docker run --gpus all -v /path/to/output:/data smollora
# docker run --gpus all -e SKIP_PREP=1 -v /path/to/processed:/data/processed smollora
#
# Every setting below can be overridden with an environment variable of the
# same name; the value after ":-" is the default.

MODEL="${MODEL:-HuggingFaceTB/SmolLM3-3B}"
DATA_DIR="${DATA_DIR:-/data/processed}"
OUTPUT_DIR="${OUTPUT_DIR:-/data/lora-output}"
EPOCHS="${EPOCHS:-3}"
BATCH_SIZE="${BATCH_SIZE:-4}"
LR="${LR:-2e-4}"
LORA_R="${LORA_R:-16}"
MAX_LENGTH="${MAX_LENGTH:-4096}"

echo "🎭 SmolLM3-3B LoRA Training Pipeline"
echo " Model: $MODEL"
echo " Data: $DATA_DIR"
echo " Output: $OUTPUT_DIR"
echo " Epochs: $EPOCHS"
echo " Batch: $BATCH_SIZE"
echo " LR: $LR"
echo " LoRA r: $LORA_R"
echo " Max len: $MAX_LENGTH"
echo ""

# Stage 1: Data preparation
# Runs unless SKIP_PREP is set to anything other than "0".
if [ "${SKIP_PREP:-0}" = "0" ]; then
    echo "━━━ Stage 1: Data Preparation ━━━"
    python /app/prepare_data.py \
        --output-dir "$DATA_DIR" \
        --model "$MODEL"
    echo "✅ Data prepared in $DATA_DIR"
else
    echo "⏭ Skipping data preparation (SKIP_PREP=1)"
fi

# Fail fast: train_lora.py reads train.jsonl and val.jsonl from the data
# directory. Checking here reports a bad mount or a skipped prep stage before
# the base model is loaded.
for split in train val; do
    if [ ! -f "$DATA_DIR/$split.jsonl" ]; then
        echo "ERROR: $DATA_DIR/$split.jsonl not found; run data preparation or mount processed data at $DATA_DIR" >&2
        exit 1
    fi
done

# Stage 2: Training
echo ""
echo "━━━ Stage 2: LoRA Training ━━━"
python /app/train_lora.py \
    --data-dir "$DATA_DIR" \
    --model "$MODEL" \
    --output-dir "$OUTPUT_DIR" \
    --epochs "$EPOCHS" \
    --batch-size "$BATCH_SIZE" \
    --lr "$LR" \
    --lora-r "$LORA_R" \
    --max-length "$MAX_LENGTH"

echo ""
echo "🎭 Training complete! Adapter saved to $OUTPUT_DIR/final"
|
||||||
249
train_lora.py
Normal file
249
train_lora.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
LoRA fine-tuning script for SmolLM3-3B tool-calling.
|
||||||
|
|
||||||
|
Uses PEFT + transformers + accelerate. Designed to run inside the Docker container.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python train_lora.py \
|
||||||
|
--data-dir /data/processed \
|
||||||
|
--model HuggingFaceTB/SmolLM3-3B \
|
||||||
|
--output-dir /data/lora-output \
|
||||||
|
--epochs 3 \
|
||||||
|
--batch-size 4 \
|
||||||
|
--lr 2e-4
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from datasets import Dataset
|
||||||
|
from peft import LoraConfig, TaskType, get_peft_model
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
DataCollatorForSeq2Seq,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_jsonl(path: Path) -> list[dict]:
    """Read a JSON Lines file into a list of decoded objects.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are ignored, every other line is parsed with
    ``json.loads``.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The decoded records, in file order.
    """
    with open(path) as handle:
        stripped = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped if entry]
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_for_training(sample: dict, tokenizer, max_length: int = 4096) -> dict:
    """Tokenize a chat-formatted sample and build assistant-only labels.

    The conversation in ``sample["messages"]`` is rendered to text with the
    tokenizer's chat template and tokenized with truncation to ``max_length``.
    Labels start fully masked (-100); for every assistant turn, the tokens from
    just after the assistant start marker through the closing end marker get
    their input id as label.

    If truncation cuts an assistant turn short, the part of the turn that
    survived truncation is still labelled.

    Args:
        sample: Mapping with a ``"messages"`` entry accepted by
            ``tokenizer.apply_chat_template``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and, when
            called, ``input_ids``, ``attention_mask`` and ``offset_mapping``.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` lists of
        equal length.
    """
    messages = sample["messages"]

    # Build the full text using the tokenizer's chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True,
    )

    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]
    labels = [-100] * len(input_ids)

    # Assistant turn boundaries as they appear in the rendered text
    ASSISTANT_MARKER = "<|im_start|>assistant\n"
    END_MARKER = "<|im_end|>"

    offsets = enc.get("offset_mapping", [])
    if not offsets:
        # Without character offsets the assistant spans cannot be located, so
        # the sample is returned with every label masked.
        # NOTE(review): a fully masked sample contributes no training signal.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    # Find assistant spans in the raw text
    pos = 0
    while True:
        start_idx = text.find(ASSISTANT_MARKER, pos)
        if start_idx == -1:
            break

        # Content starts after the marker; the span includes the end marker
        content_start = start_idx + len(ASSISTANT_MARKER)
        end_idx = text.find(END_MARKER, content_start)
        if end_idx == -1:
            span_end = len(text)
        else:
            span_end = end_idx + len(END_MARKER)

        # Map character offsets to token indices
        tok_start = None
        tok_end = None
        for ti, (cs, ce) in enumerate(offsets):
            if tok_start is None and cs >= content_start:
                tok_start = ti
            if ce >= span_end:
                tok_end = ti + 1
                break

        if tok_start is None:
            # The turn begins beyond the tokens kept after truncation; every
            # later turn starts even further right, so nothing is left to label.
            break

        if tok_end is None:
            # Truncation cut this turn short: label what survived instead of
            # dropping the whole turn.
            tok_end = len(offsets)

        for i in range(tok_start, min(tok_end, len(labels))):
            labels[i] = input_ids[i]

        pos = (end_idx if end_idx != -1 else span_end) + 1

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Fine-tune a LoRA adapter on chat data and save it.

    Steps: parse CLI options, load the tokenizer, load and tokenize
    ``train.jsonl`` / ``val.jsonl`` from ``--data-dir``, load the base model,
    wrap it with a LoRA config, train with the Hugging Face ``Trainer`` and
    write the adapter plus tokenizer to ``<output-dir>/final``.

    Precision flags: bf16 is on by default and can be turned off with
    ``--no-bf16``. ``--fp16`` selects fp16 and switches bf16 off, because
    ``TrainingArguments`` rejects both being enabled together.
    """
    parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
    parser.add_argument("--data-dir", type=str, default="/data/processed")
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
    parser.add_argument("--output-dir", type=str, default="/data/lora-output")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--max-length", type=int, default=4096)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fp16", action="store_true", default=False)
    # BooleanOptionalAction keeps "--bf16" working and adds "--no-bf16", so the
    # default-on flag can actually be disabled.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--resume-from", type=str, default=None)
    args = parser.parse_args()

    # An explicit --fp16 request wins over the bf16 default; passing both as
    # True to TrainingArguments would fail.
    if args.fp16:
        args.bf16 = False

    # Load tokenizer
    print(f"Loading tokenizer: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load data before the model so a missing or malformed dataset is reported
    # without first loading the base model weights.
    data_dir = Path(args.data_dir)
    train_data = load_jsonl(data_dir / "train.jsonl")
    val_data = load_jsonl(data_dir / "val.jsonl")
    print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")

    # Tokenize
    print("Tokenizing training data ...")
    train_dataset = Dataset.from_list(train_data).map(
        lambda x: tokenize_for_training(x, tokenizer, args.max_length),
        remove_columns=["messages"],
        desc="Tokenizing train",
    )
    val_dataset = Dataset.from_list(val_data).map(
        lambda x: tokenize_for_training(x, tokenizer, args.max_length),
        remove_columns=["messages"],
        desc="Tokenizing val",
    )

    # Load model
    print(f"Loading model: {args.model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32),
        device_map="auto",
    )

    # Configure LoRA
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=args.bf16,
        fp16=args.fp16,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_torch_fused",
        seed=args.seed,
        report_to="none",
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )

    # Train
    print("Starting training ...")
    trainer.train(resume_from_checkpoint=args.resume_from)

    # Save final adapter together with the tokenizer files
    print(f"Saving LoRA adapter to {args.output_dir}/final")
    model.save_pretrained(f"{args.output_dir}/final")
    tokenizer.save_pretrained(f"{args.output_dir}/final")

    print("Done! 🎭")
|
||||||
|
|
||||||
|
|
||||||
|
# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user