init commit
This commit is contained in:
@@ -2,18 +2,19 @@ FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
|
|||||||
|
|
||||||
# System deps
|
# System deps
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
git ninja-build packaging wget curl \
|
git ninja-build wget curl \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Python deps
|
# Python deps
|
||||||
COPY requirements.txt /tmp/requirements.txt
|
COPY requirements.txt /tmp/requirements.txt
|
||||||
RUN pip install --no-cache-dir -r /tmp/requirements.txt
|
RUN pip install --no-cache-dir -r /tmp/requirements.txt packaging
|
||||||
|
|
||||||
# Copy scripts
|
# Copy scripts
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY prepare_data.py /app/
|
COPY prepare_data.py /app/
|
||||||
COPY train_lora.py /app/
|
COPY train_lora.py /app/
|
||||||
COPY run.sh /app/
|
COPY run.sh /app/
|
||||||
|
RUN chmod +x /app/run.sh
|
||||||
|
|
||||||
# Data and output dirs
|
# Data and output dirs
|
||||||
RUN mkdir -p /data/processed /data/lora-output /data/models
|
RUN mkdir -p /data/processed /data/lora-output /data/models
|
||||||
|
|||||||
328
prepare_data.py
328
prepare_data.py
@@ -2,311 +2,147 @@
|
|||||||
"""
|
"""
|
||||||
Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning.
|
Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning.
|
||||||
|
|
||||||
Combines three datasets:
|
Datasets:
|
||||||
1. interstellarninja/tool-calls-multiturn
|
1. interstellarninja/tool-calls-multiturn
|
||||||
2. NousResearch/Hermes-Function-Calling-V1
|
2. NousResearch/Hermes-Function-Calling-V1
|
||||||
3. Salesforce/xLAM-function-calling-60k
|
|
||||||
|
|
||||||
Converts all to SmolLM3's native chat format with proper special tokens:
|
Both use ShareGPT format (from/value) with inline tool call tags.
|
||||||
- Tool calls wrapped in startPos/endPos tokens (IDs 128002/128016)
|
We convert to SmolLM3's native token format.
|
||||||
- Tool responses wrapped in eni/eni_result tokens (IDs 128013/128014)
|
|
||||||
- Thinking wrapped in think_start/think_end tags
|
|
||||||
|
|
||||||
Output: train.jsonl, val.jsonl (tokenized & raw)
|
Output: train.jsonl, val.jsonl
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import re
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
# SmolLM3 special tokens (match the fixed chat_template.jinja)
|
|
||||||
TOOL_CALL_START = "<|tool_call_start|>" # token 128002
|
|
||||||
TOOL_CALL_END = "<|tool_call_end|>" # token 128016
|
|
||||||
TOOL_RESP_START = "<|tool_response_start|>" # token 128013
|
|
||||||
TOOL_RESP_END = "<|tool_response_end|>" # token 128014
|
|
||||||
THINK_START = "<think>"
|
|
||||||
THINK_END = "</think>"
|
|
||||||
|
|
||||||
VAL_FRACTION = 0.05
|
VAL_FRACTION = 0.05
|
||||||
SEED = 42
|
SEED = 42
|
||||||
|
|
||||||
|
# Hermes-style tags (used in the source datasets)
|
||||||
|
TC_OPEN = chr(60) + "tool" + chr(62) # <tool>
|
||||||
|
TC_CLOSE = chr(60) + "/tool" + chr(62) # </tool>
|
||||||
|
TR_OPEN = chr(60) + "tool_response" + chr(62) # <tool_response>
|
||||||
|
TR_CLOSE = chr(60) + "/tool_response" + chr(62) # </tool_response>
|
||||||
|
|
||||||
def render_tool_calls(tool_calls: list[dict]) -> str:
|
# SmolLM3 native tokens
|
||||||
"""Render tool_calls list into SmolLM3's native format."""
|
SMOL_TC_START = "<|tool_call_start|>"
|
||||||
parts = []
|
SMOL_TC_END = "<|tool_call_end|>"
|
||||||
for tc in tool_calls:
|
SMOL_TR_START = "<|tool_response_start|>"
|
||||||
name = tc["function"]["name"]
|
SMOL_TR_END = "<|tool_response_end|>"
|
||||||
args = tc["function"]["arguments"]
|
|
||||||
if isinstance(args, str):
|
|
||||||
args_str = args
|
|
||||||
else:
|
|
||||||
args_str = json.dumps(args, ensure_ascii=False)
|
|
||||||
parts.append(f'{{"name": "{name}", "arguments": {args_str}}}')
|
|
||||||
body = "\n".join(parts)
|
|
||||||
return f"{TOOL_CALL_START}\n{body}\n{TOOL_CALL_END}"
|
|
||||||
|
|
||||||
|
|
||||||
def render_tool_response(content: str) -> str:
|
def convert_sharegpt_to_smollm3(conversations, tools_json=None):
|
||||||
"""Wrap tool response content in SmolLM3's tool_response tokens."""
|
"""Convert ShareGPT-style conversation to SmolLM3 messages."""
|
||||||
return f"{TOOL_RESP_START}\n{content}\n{TOOL_RESP_END}"
|
messages = []
|
||||||
|
|
||||||
|
if tools_json:
|
||||||
|
try:
|
||||||
|
tools_list = json.loads(tools_json) if isinstance(tools_json, str) else tools_json
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
tools_list = None
|
||||||
|
|
||||||
def convert_openai_messages(messages: list[dict], tools: list[dict] | None = None) -> list[dict]:
|
if tools_list:
|
||||||
"""Convert standard OpenAI-format messages to SmolLM3 native format.
|
tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools_list)
|
||||||
|
|
||||||
Transforms:
|
|
||||||
- assistant.tool_calls → content with startPos/endPos tokens
|
|
||||||
- tool role messages → user role with eni/eni_result tokens
|
|
||||||
- Adds system prompt with tool definitions if tools present
|
|
||||||
"""
|
|
||||||
converted = []
|
|
||||||
|
|
||||||
# Build system message with tool defs if present
|
|
||||||
if tools:
|
|
||||||
tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools)
|
|
||||||
system_content = (
|
system_content = (
|
||||||
"You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n"
|
"You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n"
|
||||||
"### Tools\n\n"
|
"### Tools\n\n"
|
||||||
"You may call one or more functions to assist with the user query.\n"
|
"You may call one or more functions to assist with the user query.\n"
|
||||||
"You are provided with function signatures within <tools></tools> XML tags:\n\n"
|
"You are provided with function signatures within <tools></tools> XML tags:\n\n"
|
||||||
f"<tools>\n{tool_defs}\n</tools>\n\n"
|
f"<tools>\n{tool_defs}\n</tools>\n\n"
|
||||||
'For each function call, return a json object with function name and arguments within '
|
"For each function call, return a json object with function name and arguments within "
|
||||||
f'{TOOL_CALL_START} {TOOL_CALL_END} tags:\n'
|
f"special tags:\n{SMOL_TC_START}\n"
|
||||||
f'{TOOL_CALL_START}\n{{"name": <function-name>, "arguments": <args-json-object>}}\n{TOOL_CALL_END}'
|
'{{"name": <function-name>, "arguments": <args-json-object>}}\n'
|
||||||
|
f"{SMOL_TC_END}\n"
|
||||||
)
|
)
|
||||||
converted.append({"role": "system", "content": system_content})
|
|
||||||
elif messages and messages[0].get("role") == "system":
|
|
||||||
converted.append({"role": "system", "content": messages[0]["content"]})
|
|
||||||
messages = messages[1:]
|
|
||||||
else:
|
else:
|
||||||
converted.append({
|
system_content = "You are a helpful AI assistant named SmolLM, trained by Hugging Face."
|
||||||
"role": "system",
|
|
||||||
"content": "You are a helpful AI assistant named SmolLM, trained by Hugging Face."
|
|
||||||
})
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
role = msg.get("role", "user")
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
converted.append({"role": "user", "content": msg["content"]})
|
|
||||||
|
|
||||||
elif role == "assistant":
|
|
||||||
content = msg.get("content") or ""
|
|
||||||
tool_calls = msg.get("tool_calls")
|
|
||||||
if tool_calls:
|
|
||||||
tc_text = render_tool_calls(tool_calls)
|
|
||||||
full_content = f"{content}\n{tc_text}" if content else tc_text
|
|
||||||
converted.append({"role": "assistant", "content": full_content})
|
|
||||||
else:
|
else:
|
||||||
converted.append({"role": "assistant", "content": content})
|
system_content = "You are a helpful AI assistant named SmolLM, trained by Hugging Face."
|
||||||
|
|
||||||
|
messages.append({"role": "system", "content": system_content})
|
||||||
|
|
||||||
|
for turn in conversations:
|
||||||
|
role = turn.get("from", turn.get("role", ""))
|
||||||
|
value = turn.get("value", turn.get("content", ""))
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
continue
|
||||||
|
elif role in ("human", "user"):
|
||||||
|
messages.append({"role": "user", "content": value})
|
||||||
|
elif role in ("assistant", "gpt"):
|
||||||
|
content = value.replace(TC_OPEN, SMOL_TC_START)
|
||||||
|
content = content.replace(TC_CLOSE, SMOL_TC_END)
|
||||||
|
if content.strip():
|
||||||
|
messages.append({"role": "assistant", "content": content})
|
||||||
elif role == "tool":
|
elif role == "tool":
|
||||||
# Tool responses become user messages with eni/eni_result tokens
|
content = value.replace(TR_OPEN, SMOL_TR_START)
|
||||||
content = msg.get("content", "")
|
content = content.replace(TR_CLOSE, SMOL_TR_END)
|
||||||
if isinstance(content, list):
|
messages.append({"role": "user", "content": content})
|
||||||
content = " ".join(c.get("text", "") for c in content if isinstance(c, dict))
|
|
||||||
converted.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": render_tool_response(str(content))
|
|
||||||
})
|
|
||||||
|
|
||||||
return converted
|
has_tool_call = any(
|
||||||
|
SMOL_TC_START in m.get("content", "")
|
||||||
|
for m in messages
|
||||||
|
if m["role"] == "assistant"
|
||||||
|
)
|
||||||
|
if not has_tool_call:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def load_multiturn_dataset() -> list[dict]:
|
def load_multiturn_dataset():
|
||||||
"""Load interstellarninja/tool-calls-multiturn."""
|
|
||||||
print("Loading interstellarninja/tool-calls-multiturn ...")
|
print("Loading interstellarninja/tool-calls-multiturn ...")
|
||||||
ds = load_dataset("interstellarninja/tool-calls-multiturn", split="train")
|
ds = load_dataset("interstellarninja/tool-calls-multiturn", split="train")
|
||||||
samples = []
|
samples = []
|
||||||
for row in ds:
|
for row in ds:
|
||||||
messages = row.get("messages", [])
|
conversations = row.get("conversations", [])
|
||||||
tools = row.get("tools")
|
tools = row.get("tools")
|
||||||
if not messages or not any(m.get("tool_calls") for m in messages if m.get("role") == "assistant"):
|
if not conversations:
|
||||||
continue # skip conversations with no tool calls
|
continue
|
||||||
converted = convert_openai_messages(messages, tools)
|
tools_str = tools if isinstance(tools, str) else (json.dumps(tools) if tools else None)
|
||||||
|
converted = convert_sharegpt_to_smollm3(conversations, tools_str)
|
||||||
|
if converted:
|
||||||
samples.append({"messages": converted})
|
samples.append({"messages": converted})
|
||||||
print(f" → {len(samples)} samples with tool calls")
|
print(f" -> {len(samples)} samples with tool calls")
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def load_hermes_fc_dataset() -> list[dict]:
|
def load_hermes_fc_dataset():
|
||||||
"""Load NousResearch/Hermes-Function-Calling-V1."""
|
|
||||||
print("Loading NousResearch/Hermes-Function-Calling-V1 ...")
|
print("Loading NousResearch/Hermes-Function-Calling-V1 ...")
|
||||||
ds = load_dataset("NousResearch/Hermes-Function-Calling-V1", split="train")
|
ds = load_dataset("NousResearch/Hermes-Function-Calling-V1", split="train")
|
||||||
samples = []
|
samples = []
|
||||||
for row in ds:
|
for row in ds:
|
||||||
messages = row.get("messages", [])
|
conversations = row.get("conversations", [])
|
||||||
tools = row.get("tools")
|
tools = row.get("tools")
|
||||||
if not messages or not any(m.get("tool_calls") for m in messages if m.get("role") == "assistant"):
|
if not conversations:
|
||||||
continue
|
continue
|
||||||
converted = convert_openai_messages(messages, tools)
|
tools_str = tools if isinstance(tools, str) else (json.dumps(tools) if tools else None)
|
||||||
|
converted = convert_sharegpt_to_smollm3(conversations, tools_str)
|
||||||
|
if converted:
|
||||||
samples.append({"messages": converted})
|
samples.append({"messages": converted})
|
||||||
print(f" → {len(samples)} samples with tool calls")
|
print(f" -> {len(samples)} samples with tool calls")
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def load_xlam_dataset() -> list[dict]:
|
|
||||||
"""Load Salesforce/xLAM-function-calling-60k.
|
|
||||||
|
|
||||||
This dataset uses a different format: each row has 'tools', 'instruction', and 'outputs'.
|
|
||||||
We convert to conversation format.
|
|
||||||
"""
|
|
||||||
print("Loading Salesforce/xLAM-function-calling-60k ...")
|
|
||||||
ds = load_dataset("Salesforce/xLAM-function-calling-60k", split="train")
|
|
||||||
samples = []
|
|
||||||
for row in ds:
|
|
||||||
tools_raw = row.get("tools", "[]")
|
|
||||||
instruction = row.get("instruction", "")
|
|
||||||
outputs = row.get("answers", row.get("outputs", ""))
|
|
||||||
|
|
||||||
if not instruction or not outputs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
tools_list = json.loads(tools_raw) if isinstance(tools_raw, str) else tools_raw
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not tools_list:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Parse the model output — may contain one or more tool calls
|
|
||||||
try:
|
|
||||||
output_parsed = json.loads(outputs) if isinstance(outputs, str) else outputs
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Build messages
|
|
||||||
messages = [{"role": "user", "content": instruction}]
|
|
||||||
|
|
||||||
if isinstance(output_parsed, list):
|
|
||||||
# Multiple tool calls
|
|
||||||
tool_calls = []
|
|
||||||
for item in output_parsed:
|
|
||||||
if isinstance(item, dict) and "name" in item:
|
|
||||||
tool_calls.append({
|
|
||||||
"function": {
|
|
||||||
"name": item["name"],
|
|
||||||
"arguments": item.get("arguments", item.get("parameters", {}))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if tool_calls:
|
|
||||||
messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""})
|
|
||||||
elif isinstance(output_parsed, dict) and "name" in output_parsed:
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": [{
|
|
||||||
"function": {
|
|
||||||
"name": output_parsed["name"],
|
|
||||||
"arguments": output_parsed.get("arguments", output_parsed.get("parameters", {}))
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
"content": ""
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
converted = convert_openai_messages(messages, tools_list)
|
|
||||||
samples.append({"messages": converted})
|
|
||||||
|
|
||||||
print(f" → {len(samples)} samples with tool calls")
|
|
||||||
return samples
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_sample(sample: dict, tokenizer) -> dict | None:
|
|
||||||
"""Tokenize a sample using the model's chat template.
|
|
||||||
|
|
||||||
Returns dict with input_ids, attention_mask, labels (with system/user masked to -100).
|
|
||||||
"""
|
|
||||||
messages = sample["messages"]
|
|
||||||
try:
|
|
||||||
text = tokenizer.apply_chat_template(
|
|
||||||
messages,
|
|
||||||
tokenize=False,
|
|
||||||
add_generation_prompt=False,
|
|
||||||
)
|
|
||||||
enc = tokenizer(text, truncation=True, max_length=4096)
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ⚠ Tokenization failed: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
input_ids = enc["input_ids"]
|
|
||||||
attention_mask = enc["attention_mask"]
|
|
||||||
|
|
||||||
# Build labels: mask system + user tokens, only train on assistant responses
|
|
||||||
labels = [-100] * len(input_ids)
|
|
||||||
|
|
||||||
# Find assistant turn boundaries in the raw text
|
|
||||||
# We'll use a simpler approach: decode chunks and find assistant markers
|
|
||||||
ASSISTANT_START = "<|im_start|>assistant\n"
|
|
||||||
IM_END = "<|im_end|>"
|
|
||||||
|
|
||||||
# Find all assistant spans in the tokenized text by decoding ranges
|
|
||||||
text_for_search = text
|
|
||||||
pos = 0
|
|
||||||
while True:
|
|
||||||
start_idx = text_for_search.find(ASSISTANT_START, pos)
|
|
||||||
if start_idx == -1:
|
|
||||||
break
|
|
||||||
end_idx = text_for_search.find(IM_END, start_idx + len(ASSISTANT_START))
|
|
||||||
if end_idx == -1:
|
|
||||||
end_idx = len(text_for_search)
|
|
||||||
|
|
||||||
# Map character offsets to token offsets
|
|
||||||
# Approximate: count characters up to start/end, find token boundaries
|
|
||||||
char_to_start = start_idx + len(ASSISTANT_START) # skip the marker itself
|
|
||||||
char_to_end = end_idx + len(IM_END)
|
|
||||||
|
|
||||||
# Use tokenizer offset mapping if available
|
|
||||||
enc_with_offsets = tokenizer(text, truncation=True, max_length=4096, return_offsets_mapping=True)
|
|
||||||
offsets = enc_with_offsets.get("offset_mapping", None)
|
|
||||||
|
|
||||||
if offsets:
|
|
||||||
tok_start = None
|
|
||||||
tok_end = None
|
|
||||||
for ti, (cs, ce) in enumerate(offsets):
|
|
||||||
if cs >= char_to_start and tok_start is None:
|
|
||||||
tok_start = ti
|
|
||||||
if ce >= char_to_end:
|
|
||||||
tok_end = ti + 1
|
|
||||||
break
|
|
||||||
if tok_start is not None and tok_end is not None:
|
|
||||||
for i in range(tok_start, min(tok_end, len(labels))):
|
|
||||||
labels[i] = input_ids[i]
|
|
||||||
|
|
||||||
pos = end_idx + 1
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"labels": labels,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--output-dir", type=str, default="/data/processed")
|
parser.add_argument("--output-dir", type=str, default="/data")
|
||||||
parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)")
|
parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)")
|
||||||
parser.add_argument("--tokenize", action="store_true", help="Also produce tokenized versions")
|
|
||||||
parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Load all datasets
|
|
||||||
all_samples = []
|
all_samples = []
|
||||||
all_samples.extend(load_multiturn_dataset())
|
all_samples.extend(load_multiturn_dataset())
|
||||||
all_samples.extend(load_hermes_fc_dataset())
|
all_samples.extend(load_hermes_fc_dataset())
|
||||||
all_samples.extend(load_xlam_dataset())
|
|
||||||
|
|
||||||
print(f"\nTotal raw samples: {len(all_samples)}")
|
print(f"\nTotal raw samples: {len(all_samples)}")
|
||||||
|
|
||||||
# Shuffle & split
|
|
||||||
random.seed(SEED)
|
random.seed(SEED)
|
||||||
random.shuffle(all_samples)
|
random.shuffle(all_samples)
|
||||||
|
|
||||||
@@ -319,7 +155,6 @@ def main():
|
|||||||
|
|
||||||
print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")
|
print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")
|
||||||
|
|
||||||
# Write raw JSONL
|
|
||||||
for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
|
for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
|
||||||
path = output_dir / f"{split_name}.jsonl"
|
path = output_dir / f"{split_name}.jsonl"
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
@@ -327,22 +162,7 @@ def main():
|
|||||||
f.write(json.dumps(s, ensure_ascii=False) + "\n")
|
f.write(json.dumps(s, ensure_ascii=False) + "\n")
|
||||||
print(f"Wrote {path}")
|
print(f"Wrote {path}")
|
||||||
|
|
||||||
# Optionally tokenize
|
print("Data preparation complete!")
|
||||||
if args.tokenize:
|
|
||||||
print(f"\nTokenizing with {args.model} ...")
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
|
||||||
|
|
||||||
for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
|
|
||||||
tok_path = output_dir / f"{split_name}_tokenized.jsonl"
|
|
||||||
count = 0
|
|
||||||
with open(tok_path, "w") as f:
|
|
||||||
for s in split_data:
|
|
||||||
tok = tokenize_sample(s, tokenizer)
|
|
||||||
if tok:
|
|
||||||
f.write(json.dumps(tok) + "\n")
|
|
||||||
count += 1
|
|
||||||
print(f"Wrote {tok_path} ({count} samples)")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -2,13 +2,13 @@
|
|||||||
"""
|
"""
|
||||||
LoRA fine-tuning script for SmolLM3-3B tool-calling.
|
LoRA fine-tuning script for SmolLM3-3B tool-calling.
|
||||||
|
|
||||||
Uses PEFT + transformers + accelerate. Designed to run inside the Docker container.
|
Uses PEFT + transformers + accelerate. Runs inside the Docker container.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python train_lora.py \
|
python train_lora.py \
|
||||||
--data-dir /data/processed \
|
--data-dir /data \
|
||||||
--model HuggingFaceTB/SmolLM3-3B \
|
--model HuggingFaceTB/SmolLM3-3B \
|
||||||
--output-dir /data/lora-output \
|
--output-dir /output \
|
||||||
--epochs 3 \
|
--epochs 3 \
|
||||||
--batch-size 4 \
|
--batch-size 4 \
|
||||||
--lr 2e-4
|
--lr 2e-4
|
||||||
@@ -30,7 +30,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_jsonl(path: Path) -> list[dict]:
|
def load_jsonl(path):
|
||||||
samples = []
|
samples = []
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
@@ -40,14 +40,12 @@ def load_jsonl(path: Path) -> list[dict]:
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def tokenize_for_training(sample: dict, tokenizer, max_length: int = 4096) -> dict:
|
def tokenize_for_training(sample, tokenizer, max_length=4096):
|
||||||
"""Tokenize a chat-formatted sample and build labels.
|
"""Tokenize a chat-formatted sample and build labels.
|
||||||
|
|
||||||
Masks everything except assistant responses (labels = -100 for non-assistant tokens).
|
Masks everything except assistant responses (labels = -100 for non-assistant tokens).
|
||||||
"""
|
"""
|
||||||
messages = sample["messages"]
|
messages = sample["messages"]
|
||||||
|
|
||||||
# Build the full text using the tokenizer's chat template
|
|
||||||
text = tokenizer.apply_chat_template(
|
text = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
@@ -65,35 +63,34 @@ def tokenize_for_training(sample: dict, tokenizer, max_length: int = 4096) -> di
|
|||||||
attention_mask = enc["attention_mask"]
|
attention_mask = enc["attention_mask"]
|
||||||
labels = [-100] * len(input_ids)
|
labels = [-100] * len(input_ids)
|
||||||
|
|
||||||
# Find assistant turn boundaries
|
ASSISTANT_MARKER = "<|im_start|>assistant"
|
||||||
ASSISTANT_MARKER = "<|im_start|>assistant\n"
|
|
||||||
END_MARKER = "<|im_end|>"
|
END_MARKER = "<|im_end|>"
|
||||||
|
|
||||||
offsets = enc.get("offset_mapping", [])
|
offsets = enc.get("offset_mapping", [])
|
||||||
if not offsets:
|
if not offsets:
|
||||||
# Fallback: just train on everything after first assistant turn
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Find assistant spans in the raw text
|
|
||||||
pos = 0
|
pos = 0
|
||||||
while True:
|
while True:
|
||||||
start_idx = text.find(ASSISTANT_MARKER, pos)
|
start_idx = text.find(ASSISTANT_MARKER, pos)
|
||||||
if start_idx == -1:
|
if start_idx == -1:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Content starts after the marker
|
# Content starts after the marker + newline
|
||||||
content_start = start_idx + len(ASSISTANT_MARKER)
|
content_start = start_idx + len(ASSISTANT_MARKER)
|
||||||
|
if content_start < len(text) and text[content_start] == "\n":
|
||||||
|
content_start += 1
|
||||||
|
|
||||||
end_idx = text.find(END_MARKER, content_start)
|
end_idx = text.find(END_MARKER, content_start)
|
||||||
if end_idx == -1:
|
if end_idx == -1:
|
||||||
span_end = len(text)
|
span_end = len(text)
|
||||||
else:
|
else:
|
||||||
span_end = end_idx + len(END_MARKER)
|
span_end = end_idx + len(END_MARKER)
|
||||||
|
|
||||||
# Map character offsets to token indices
|
|
||||||
tok_start = None
|
tok_start = None
|
||||||
tok_end = None
|
tok_end = None
|
||||||
for ti, (cs, ce) in enumerate(offsets):
|
for ti, (cs, ce) in enumerate(offsets):
|
||||||
@@ -118,9 +115,9 @@ def tokenize_for_training(sample: dict, tokenizer, max_length: int = 4096) -> di
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
|
parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
|
||||||
parser.add_argument("--data-dir", type=str, default="/data/processed")
|
parser.add_argument("--data-dir", type=str, default="/data")
|
||||||
parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
|
parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
|
||||||
parser.add_argument("--output-dir", type=str, default="/data/lora-output")
|
parser.add_argument("--output-dir", type=str, default="/output")
|
||||||
parser.add_argument("--epochs", type=int, default=3)
|
parser.add_argument("--epochs", type=int, default=3)
|
||||||
parser.add_argument("--batch-size", type=int, default=4)
|
parser.add_argument("--batch-size", type=int, default=4)
|
||||||
parser.add_argument("--grad-accum", type=int, default=4)
|
parser.add_argument("--grad-accum", type=int, default=4)
|
||||||
@@ -136,13 +133,11 @@ def main():
|
|||||||
parser.add_argument("--resume-from", type=str, default=None)
|
parser.add_argument("--resume-from", type=str, default=None)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Load tokenizer
|
|
||||||
print(f"Loading tokenizer: {args.model}")
|
print(f"Loading tokenizer: {args.model}")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
||||||
if tokenizer.pad_token is None:
|
if tokenizer.pad_token is None:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
# Load model
|
|
||||||
print(f"Loading model: {args.model}")
|
print(f"Loading model: {args.model}")
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
args.model,
|
args.model,
|
||||||
@@ -151,7 +146,6 @@ def main():
|
|||||||
device_map="auto",
|
device_map="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Configure LoRA
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
task_type=TaskType.CAUSAL_LM,
|
task_type=TaskType.CAUSAL_LM,
|
||||||
r=args.lora_r,
|
r=args.lora_r,
|
||||||
@@ -167,13 +161,11 @@ def main():
|
|||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
# Load data
|
|
||||||
data_dir = Path(args.data_dir)
|
data_dir = Path(args.data_dir)
|
||||||
train_data = load_jsonl(data_dir / "train.jsonl")
|
train_data = load_jsonl(data_dir / "train.jsonl")
|
||||||
val_data = load_jsonl(data_dir / "val.jsonl")
|
val_data = load_jsonl(data_dir / "val.jsonl")
|
||||||
print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")
|
print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")
|
||||||
|
|
||||||
# Tokenize
|
|
||||||
print("Tokenizing training data ...")
|
print("Tokenizing training data ...")
|
||||||
train_dataset = Dataset.from_list(train_data).map(
|
train_dataset = Dataset.from_list(train_data).map(
|
||||||
lambda x: tokenize_for_training(x, tokenizer, args.max_length),
|
lambda x: tokenize_for_training(x, tokenizer, args.max_length),
|
||||||
@@ -186,14 +178,12 @@ def main():
|
|||||||
desc="Tokenizing val",
|
desc="Tokenizing val",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Data collator
|
|
||||||
data_collator = DataCollatorForSeq2Seq(
|
data_collator = DataCollatorForSeq2Seq(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
padding=True,
|
padding=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training arguments
|
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir=args.output_dir,
|
output_dir=args.output_dir,
|
||||||
num_train_epochs=args.epochs,
|
num_train_epochs=args.epochs,
|
||||||
@@ -223,7 +213,6 @@ def main():
|
|||||||
dataloader_pin_memory=True,
|
dataloader_pin_memory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trainer
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
@@ -232,17 +221,14 @@ def main():
|
|||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Train
|
|
||||||
print("Starting training ...")
|
print("Starting training ...")
|
||||||
trainer.train(resume_from_checkpoint=args.resume_from)
|
trainer.train(resume_from_checkpoint=args.resume_from)
|
||||||
|
|
||||||
# Save final adapter
|
|
||||||
print(f"Saving LoRA adapter to {args.output_dir}/final")
|
print(f"Saving LoRA adapter to {args.output_dir}/final")
|
||||||
model.save_pretrained(f"{args.output_dir}/final")
|
model.save_pretrained(f"{args.output_dir}/final")
|
||||||
tokenizer.save_pretrained(f"{args.output_dir}/final")
|
tokenizer.save_pretrained(f"{args.output_dir}/final")
|
||||||
|
|
||||||
# Also save the tokenizer chat template for deployment
|
print("Done!")
|
||||||
print("Done! 🎭")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user