# NOTE(review): the lines below were file-listing chrome pasted in from a UI
# ("Files / smollora/prepare_data.py / 170 lines / 5.7 KiB / Python") and are
# a syntax error as bare text; preserved here as a comment so the file parses.
# File: smollora/prepare_data.py (170 lines, 5.7 KiB, Python)
#!/usr/bin/env python3
"""
Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning.
Datasets:
1. interstellarninja/tool-calls-multiturn
2. NousResearch/Hermes-Function-Calling-V1
Both use ShareGPT format (from/value) with inline tool call tags.
We convert to SmolLM3's native token format.
Output: train.jsonl, val.jsonl
"""
import json
import random
from pathlib import Path
from datasets import load_dataset
VAL_FRACTION = 0.05
SEED = 42

# Tags used in the source datasets (assembled via chr() so the literal tags
# never appear verbatim in this source file).
TC_OPEN = chr(60) + "tool_call" + chr(62)
TC_CLOSE = chr(60) + "/tool_call" + chr(62)
TR_OPEN = chr(60) + "tool_response" + chr(62)
TR_CLOSE = chr(60) + "/tool_response" + chr(62)

# SmolLM3 native tokens
SMOL_TC_START = "<|tool_call_start|>"
SMOL_TC_END = "<|tool_call_end|>"
SMOL_TR_START = "<|tool_response_start|>"
SMOL_TR_END = "<|tool_response_end|>"

_DEFAULT_SYSTEM = "You are a helpful AI assistant named SmolLM, trained by Hugging Face."


def convert_sharegpt_to_smollm3(conversations, tools_json=None):
    """Convert a ShareGPT-style conversation to SmolLM3 chat messages.

    Args:
        conversations: list of turn dicts carrying ``from``/``value`` (or
            ``role``/``content``) keys.
        tools_json: optional tool definitions — either a JSON string or an
            already-parsed list. Unparseable strings are treated as "no tools".

    Returns:
        A list of ``{"role", "content"}`` dicts starting with a synthesized
        system message, or ``None`` when no assistant turn contains a tool
        call (such samples are filtered out of the training set).
    """
    tools_list = None
    if tools_json:
        try:
            tools_list = json.loads(tools_json) if isinstance(tools_json, str) else tools_json
        except json.JSONDecodeError:
            tools_list = None

    if tools_list:
        tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools_list)
        # BUGFIX: the JSON example line was previously a *plain* string
        # written with doubled braces ({{...}}), which leaked "{{" literally
        # into the prompt. Single braces are correct here.
        system_content = (
            _DEFAULT_SYSTEM + "\n\n"
            "### Tools\n\n"
            "You may call one or more functions to assist with the user query.\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n\n"
            f"<tools>\n{tool_defs}\n</tools>\n\n"
            "For each function call, return a json object with function name and arguments within "
            f"special tags:\n{SMOL_TC_START}\n"
            '{"name": <function-name>, "arguments": <args-json-object>}\n'
            f"{SMOL_TC_END}\n"
        )
    else:
        system_content = _DEFAULT_SYSTEM

    messages = [{"role": "system", "content": system_content}]
    for turn in conversations:
        role = turn.get("from", turn.get("role", ""))
        value = turn.get("value", turn.get("content", ""))
        if role == "system":
            # Source system prompts are dropped; we supply our own above.
            continue
        elif role in ("human", "user"):
            messages.append({"role": "user", "content": value})
        elif role in ("assistant", "gpt"):
            content = value.replace(TC_OPEN, SMOL_TC_START)
            content = content.replace(TC_CLOSE, SMOL_TC_END)
            if content.strip():  # skip empty assistant turns
                messages.append({"role": "assistant", "content": content})
        elif role == "tool":
            # Tool results are re-tagged with SmolLM3 tokens and emitted with
            # role "user" — NOTE(review): confirm this matches the SmolLM3
            # chat template's expectation for tool responses.
            content = value.replace(TR_OPEN, SMOL_TR_START)
            content = content.replace(TR_CLOSE, SMOL_TR_END)
            messages.append({"role": "user", "content": content})

    # Keep only conversations that actually demonstrate tool calling.
    has_tool_call = any(
        SMOL_TC_START in m.get("content", "")
        for m in messages
        if m["role"] == "assistant"
    )
    if not has_tool_call:
        return None
    return messages
def _load_sharegpt_tool_dataset(repo_id):
    """Load one ShareGPT-format tool-calling dataset and convert every row.

    The two public loaders below were byte-for-byte duplicates except for the
    Hub repo id; this helper holds the shared logic.

    Args:
        repo_id: Hugging Face Hub dataset id, loaded with split="train".

    Returns:
        List of ``{"messages": [...]}`` samples that contain tool calls.
    """
    print(f"Loading {repo_id} ...")
    ds = load_dataset(repo_id, split="train")
    samples = []
    for row in ds:
        conversations = row.get("conversations", [])
        tools = row.get("tools")
        if not conversations:
            continue
        # Normalize tools to a JSON string (or None) for the converter.
        tools_str = tools if isinstance(tools, str) else (json.dumps(tools) if tools else None)
        converted = convert_sharegpt_to_smollm3(conversations, tools_str)
        if converted:
            samples.append({"messages": converted})
    print(f" -> {len(samples)} samples with tool calls")
    return samples


def load_multiturn_dataset():
    """Load interstellarninja/tool-calls-multiturn as SmolLM3 samples."""
    return _load_sharegpt_tool_dataset("interstellarninja/tool-calls-multiturn")


def load_hermes_fc_dataset():
    """Load NousResearch/Hermes-Function-Calling-V1 as SmolLM3 samples."""
    return _load_sharegpt_tool_dataset("NousResearch/Hermes-Function-Calling-V1")
def main():
    """Build train.jsonl / val.jsonl from the two tool-calling datasets.

    Flags:
        --output-dir: destination directory (default: /data).
        --max-samples: cap on total samples after shuffling (0 = keep all).
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str, default="/data")
    parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    all_samples = []
    all_samples.extend(load_multiturn_dataset())
    all_samples.extend(load_hermes_fc_dataset())
    print(f"\nTotal raw samples: {len(all_samples)}")

    # Seeded shuffle so the train/val split is reproducible across runs.
    random.seed(SEED)
    random.shuffle(all_samples)
    if args.max_samples > 0:
        all_samples = all_samples[:args.max_samples]

    # Hold out VAL_FRACTION for validation, with at least one sample.
    val_count = max(1, int(len(all_samples) * VAL_FRACTION))
    val_samples = all_samples[:val_count]
    train_samples = all_samples[val_count:]
    print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")

    for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
        path = output_dir / f"{split_name}.jsonl"
        # BUGFIX: json.dumps(..., ensure_ascii=False) emits non-ASCII text, so
        # the file must be opened with an explicit utf-8 encoding — the
        # platform default (e.g. cp1252 on Windows) could crash or corrupt it.
        with open(path, "w", encoding="utf-8") as f:
            for s in split_data:
                f.write(json.dumps(s, ensure_ascii=False) + "\n")
        print(f"Wrote {path}")
    print("Data preparation complete!")


if __name__ == "__main__":
    main()