#!/usr/bin/env python3
"""
Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning.

Datasets:
1. interstellarninja/tool-calls-multiturn
2. NousResearch/Hermes-Function-Calling-V1

Both use ShareGPT format (from/value) with inline tool call tags.
We convert to SmolLM3's native token format.

Output: train.jsonl, val.jsonl
"""
|
|
|
|
import json
|
|
import random
|
|
from pathlib import Path
|
|
|
|
from datasets import load_dataset
|
|
|
|
# Train/validation split configuration.
VAL_FRACTION = 0.05
SEED = 42

# Inline tags used by the source datasets (ShareGPT / Hermes style).
TC_OPEN = "<tool_call>"
TC_CLOSE = "</tool_call>"
TR_OPEN = "<tool_response>"
TR_CLOSE = "</tool_response>"

# SmolLM3 native special tokens that replace the tags above.
SMOL_TC_START = "<|tool_call_start|>"
SMOL_TC_END = "<|tool_call_end|>"
SMOL_TR_START = "<|tool_response_start|>"
SMOL_TR_END = "<|tool_response_end|>"

# Base system prompt, extended with tool signatures when tools are provided.
DEFAULT_SYSTEM = "You are a helpful AI assistant named SmolLM, trained by Hugging Face."


def convert_sharegpt_to_smollm3(conversations, tools_json=None):
    """Convert a ShareGPT-style conversation to SmolLM3 chat messages.

    Args:
        conversations: list of turn dicts keyed "from"/"value" (or
            "role"/"content" as a fallback).
        tools_json: optional tool definitions — a JSON string or an
            already-parsed list. Malformed JSON is treated as no tools.

    Returns:
        A list of {"role", "content"} messages starting with a system
        prompt, or None when no assistant turn contains a tool call
        (such samples are useless for tool-calling fine-tuning).
    """
    system_content = DEFAULT_SYSTEM

    tools_list = None
    if tools_json:
        try:
            tools_list = json.loads(tools_json) if isinstance(tools_json, str) else tools_json
        except json.JSONDecodeError:
            tools_list = None  # malformed tool spec -> fall back to plain system prompt

    if tools_list:
        # One tool signature per line inside <tools>...</tools>.
        tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools_list)
        # BUG FIX: the schema line below was a plain (non-f) string with
        # doubled braces, so the literal text "{{...}}" leaked into the
        # training prompt. Single braces are the intended output.
        system_content = (
            f"{DEFAULT_SYSTEM}\n\n"
            "### Tools\n\n"
            "You may call one or more functions to assist with the user query.\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n\n"
            f"<tools>\n{tool_defs}\n</tools>\n\n"
            "For each function call, return a json object with function name and arguments within "
            f"special tags:\n{SMOL_TC_START}\n"
            '{"name": <function-name>, "arguments": <args-json-object>}\n'
            f"{SMOL_TC_END}\n"
        )

    messages = [{"role": "system", "content": system_content}]

    for turn in conversations:
        role = turn.get("from", turn.get("role", ""))
        value = turn.get("value", turn.get("content", ""))

        if role == "system":
            # Source system prompts are discarded; we supply our own above.
            continue
        elif role in ("human", "user"):
            messages.append({"role": "user", "content": value})
        elif role in ("assistant", "gpt"):
            content = value.replace(TC_OPEN, SMOL_TC_START).replace(TC_CLOSE, SMOL_TC_END)
            if content.strip():  # drop empty assistant turns
                messages.append({"role": "assistant", "content": content})
        elif role == "tool":
            # NOTE(review): tool responses are emitted as "user" turns —
            # presumably matching SmolLM3's chat template roles; confirm.
            content = value.replace(TR_OPEN, SMOL_TR_START).replace(TR_CLOSE, SMOL_TR_END)
            messages.append({"role": "user", "content": content})

    # Keep only samples where the assistant actually calls a tool.
    has_tool_call = any(
        SMOL_TC_START in m.get("content", "")
        for m in messages
        if m["role"] == "assistant"
    )
    if not has_tool_call:
        return None

    return messages
|
|
|
|
|
|
def _load_sharegpt_tool_dataset(dataset_name):
    """Load a ShareGPT-format HF dataset and convert rows to SmolLM3 samples.

    Rows without conversations, or whose conversion yields no assistant
    tool call, are dropped. Returns a list of {"messages": [...]} dicts.

    (The two public loaders below were byte-identical copies; this helper
    removes the duplication.)
    """
    print(f"Loading {dataset_name} ...")
    ds = load_dataset(dataset_name, split="train")
    samples = []
    for row in ds:
        conversations = row.get("conversations", [])
        if not conversations:
            continue
        tools = row.get("tools")
        # The converter accepts a JSON string or None; serialize anything else.
        tools_str = tools if isinstance(tools, str) else (json.dumps(tools) if tools else None)
        converted = convert_sharegpt_to_smollm3(conversations, tools_str)
        if converted:
            samples.append({"messages": converted})
    print(f" -> {len(samples)} samples with tool calls")
    return samples


def load_multiturn_dataset():
    """Load interstellarninja/tool-calls-multiturn as SmolLM3 samples."""
    return _load_sharegpt_tool_dataset("interstellarninja/tool-calls-multiturn")


def load_hermes_fc_dataset():
    """Load NousResearch/Hermes-Function-Calling-V1 as SmolLM3 samples."""
    return _load_sharegpt_tool_dataset("NousResearch/Hermes-Function-Calling-V1")
|
|
|
|
|
|
def main():
    """Build the combined dataset, shuffle, split, and write JSONL files.

    Flags:
        --output-dir: directory for train.jsonl / val.jsonl (default /data).
        --max-samples: cap on total samples after shuffling (0 = keep all).
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str, default="/data")
    parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    all_samples = []
    all_samples.extend(load_multiturn_dataset())
    all_samples.extend(load_hermes_fc_dataset())

    print(f"\nTotal raw samples: {len(all_samples)}")

    # Deterministic shuffle so the train/val split is reproducible.
    random.seed(SEED)
    random.shuffle(all_samples)

    if args.max_samples > 0:
        all_samples = all_samples[: args.max_samples]

    # Hold out at least one sample for validation.
    val_count = max(1, int(len(all_samples) * VAL_FRACTION))
    val_samples = all_samples[:val_count]
    train_samples = all_samples[val_count:]

    print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")

    for split_name, split_data in (("train", train_samples), ("val", val_samples)):
        path = output_dir / f"{split_name}.jsonl"
        # FIX: explicit UTF-8 — we write ensure_ascii=False JSON, which would
        # crash (or mojibake) under a non-UTF-8 default locale encoding.
        with path.open("w", encoding="utf-8") as f:
            for sample in split_data:
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")
        print(f"Wrote {path}")

    print("Data preparation complete!")


if __name__ == "__main__":
    main()
|