Initial LoRA training setup for SmolLM3-3B tool calling
This commit is contained in:
349
prepare_data.py
Normal file
349
prepare_data.py
Normal file
@@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning.
|
||||
|
||||
Combines three datasets:
|
||||
1. interstellarninja/tool-calls-multiturn
|
||||
2. NousResearch/Hermes-Function-Calling-V1
|
||||
3. Salesforce/xLAM-function-calling-60k
|
||||
|
||||
Converts all to SmolLM3's native chat format with proper special tokens:
|
||||
- Tool calls wrapped in startPos/endPos tokens (IDs 128002/128016)
|
||||
- Tool responses wrapped in eni/eni_result tokens (IDs 128013/128014)
|
||||
- Thinking wrapped in think_start/think_end tags
|
||||
|
||||
Output: train.jsonl, val.jsonl (tokenized & raw)
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
# SmolLM3 special tokens (match the fixed chat_template.jinja)
|
||||
TOOL_CALL_START = "<|tool_call_start|>" # token 128002
|
||||
TOOL_CALL_END = "<|tool_call_end|>" # token 128016
|
||||
TOOL_RESP_START = "<|tool_response_start|>" # token 128013
|
||||
TOOL_RESP_END = "<|tool_response_end|>" # token 128014
|
||||
THINK_START = "<think>"
|
||||
THINK_END = "</think>"
|
||||
|
||||
VAL_FRACTION = 0.05
|
||||
SEED = 42
|
||||
|
||||
|
||||
def render_tool_calls(tool_calls: list[dict]) -> str:
|
||||
"""Render tool_calls list into SmolLM3's native format."""
|
||||
parts = []
|
||||
for tc in tool_calls:
|
||||
name = tc["function"]["name"]
|
||||
args = tc["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args_str = args
|
||||
else:
|
||||
args_str = json.dumps(args, ensure_ascii=False)
|
||||
parts.append(f'{{"name": "{name}", "arguments": {args_str}}}')
|
||||
body = "\n".join(parts)
|
||||
return f"{TOOL_CALL_START}\n{body}\n{TOOL_CALL_END}"
|
||||
|
||||
|
||||
def render_tool_response(content: str) -> str:
|
||||
"""Wrap tool response content in SmolLM3's tool_response tokens."""
|
||||
return f"{TOOL_RESP_START}\n{content}\n{TOOL_RESP_END}"
|
||||
|
||||
|
||||
def convert_openai_messages(messages: list[dict], tools: list[dict] | None = None) -> list[dict]:
|
||||
"""Convert standard OpenAI-format messages to SmolLM3 native format.
|
||||
|
||||
Transforms:
|
||||
- assistant.tool_calls → content with startPos/endPos tokens
|
||||
- tool role messages → user role with eni/eni_result tokens
|
||||
- Adds system prompt with tool definitions if tools present
|
||||
"""
|
||||
converted = []
|
||||
|
||||
# Build system message with tool defs if present
|
||||
if tools:
|
||||
tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools)
|
||||
system_content = (
|
||||
"You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n"
|
||||
"### Tools\n\n"
|
||||
"You may call one or more functions to assist with the user query.\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n\n"
|
||||
f"<tools>\n{tool_defs}\n</tools>\n\n"
|
||||
'For each function call, return a json object with function name and arguments within '
|
||||
f'{TOOL_CALL_START} {TOOL_CALL_END} tags:\n'
|
||||
f'{TOOL_CALL_START}\n{{"name": <function-name>, "arguments": <args-json-object>}}\n{TOOL_CALL_END}'
|
||||
)
|
||||
converted.append({"role": "system", "content": system_content})
|
||||
elif messages and messages[0].get("role") == "system":
|
||||
converted.append({"role": "system", "content": messages[0]["content"]})
|
||||
messages = messages[1:]
|
||||
else:
|
||||
converted.append({
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant named SmolLM, trained by Hugging Face."
|
||||
})
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
|
||||
if role == "user":
|
||||
converted.append({"role": "user", "content": msg["content"]})
|
||||
|
||||
elif role == "assistant":
|
||||
content = msg.get("content") or ""
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if tool_calls:
|
||||
tc_text = render_tool_calls(tool_calls)
|
||||
full_content = f"{content}\n{tc_text}" if content else tc_text
|
||||
converted.append({"role": "assistant", "content": full_content})
|
||||
else:
|
||||
converted.append({"role": "assistant", "content": content})
|
||||
|
||||
elif role == "tool":
|
||||
# Tool responses become user messages with eni/eni_result tokens
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
content = " ".join(c.get("text", "") for c in content if isinstance(c, dict))
|
||||
converted.append({
|
||||
"role": "user",
|
||||
"content": render_tool_response(str(content))
|
||||
})
|
||||
|
||||
return converted
|
||||
|
||||
|
||||
def load_multiturn_dataset() -> list[dict]:
|
||||
"""Load interstellarninja/tool-calls-multiturn."""
|
||||
print("Loading interstellarninja/tool-calls-multiturn ...")
|
||||
ds = load_dataset("interstellarninja/tool-calls-multiturn", split="train")
|
||||
samples = []
|
||||
for row in ds:
|
||||
messages = row.get("messages", [])
|
||||
tools = row.get("tools")
|
||||
if not messages or not any(m.get("tool_calls") for m in messages if m.get("role") == "assistant"):
|
||||
continue # skip conversations with no tool calls
|
||||
converted = convert_openai_messages(messages, tools)
|
||||
samples.append({"messages": converted})
|
||||
print(f" → {len(samples)} samples with tool calls")
|
||||
return samples
|
||||
|
||||
|
||||
def load_hermes_fc_dataset() -> list[dict]:
|
||||
"""Load NousResearch/Hermes-Function-Calling-V1."""
|
||||
print("Loading NousResearch/Hermes-Function-Calling-V1 ...")
|
||||
ds = load_dataset("NousResearch/Hermes-Function-Calling-V1", split="train")
|
||||
samples = []
|
||||
for row in ds:
|
||||
messages = row.get("messages", [])
|
||||
tools = row.get("tools")
|
||||
if not messages or not any(m.get("tool_calls") for m in messages if m.get("role") == "assistant"):
|
||||
continue
|
||||
converted = convert_openai_messages(messages, tools)
|
||||
samples.append({"messages": converted})
|
||||
print(f" → {len(samples)} samples with tool calls")
|
||||
return samples
|
||||
|
||||
|
||||
def load_xlam_dataset() -> list[dict]:
|
||||
"""Load Salesforce/xLAM-function-calling-60k.
|
||||
|
||||
This dataset uses a different format: each row has 'tools', 'instruction', and 'outputs'.
|
||||
We convert to conversation format.
|
||||
"""
|
||||
print("Loading Salesforce/xLAM-function-calling-60k ...")
|
||||
ds = load_dataset("Salesforce/xLAM-function-calling-60k", split="train")
|
||||
samples = []
|
||||
for row in ds:
|
||||
tools_raw = row.get("tools", "[]")
|
||||
instruction = row.get("instruction", "")
|
||||
outputs = row.get("answers", row.get("outputs", ""))
|
||||
|
||||
if not instruction or not outputs:
|
||||
continue
|
||||
|
||||
try:
|
||||
tools_list = json.loads(tools_raw) if isinstance(tools_raw, str) else tools_raw
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if not tools_list:
|
||||
continue
|
||||
|
||||
# Parse the model output — may contain one or more tool calls
|
||||
try:
|
||||
output_parsed = json.loads(outputs) if isinstance(outputs, str) else outputs
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Build messages
|
||||
messages = [{"role": "user", "content": instruction}]
|
||||
|
||||
if isinstance(output_parsed, list):
|
||||
# Multiple tool calls
|
||||
tool_calls = []
|
||||
for item in output_parsed:
|
||||
if isinstance(item, dict) and "name" in item:
|
||||
tool_calls.append({
|
||||
"function": {
|
||||
"name": item["name"],
|
||||
"arguments": item.get("arguments", item.get("parameters", {}))
|
||||
}
|
||||
})
|
||||
if tool_calls:
|
||||
messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""})
|
||||
elif isinstance(output_parsed, dict) and "name" in output_parsed:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": output_parsed["name"],
|
||||
"arguments": output_parsed.get("arguments", output_parsed.get("parameters", {}))
|
||||
}
|
||||
}],
|
||||
"content": ""
|
||||
})
|
||||
else:
|
||||
continue
|
||||
|
||||
converted = convert_openai_messages(messages, tools_list)
|
||||
samples.append({"messages": converted})
|
||||
|
||||
print(f" → {len(samples)} samples with tool calls")
|
||||
return samples
|
||||
|
||||
|
||||
def tokenize_sample(sample: dict, tokenizer) -> dict | None:
|
||||
"""Tokenize a sample using the model's chat template.
|
||||
|
||||
Returns dict with input_ids, attention_mask, labels (with system/user masked to -100).
|
||||
"""
|
||||
messages = sample["messages"]
|
||||
try:
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
enc = tokenizer(text, truncation=True, max_length=4096)
|
||||
except Exception as e:
|
||||
print(f" ⚠ Tokenization failed: {e}")
|
||||
return None
|
||||
|
||||
input_ids = enc["input_ids"]
|
||||
attention_mask = enc["attention_mask"]
|
||||
|
||||
# Build labels: mask system + user tokens, only train on assistant responses
|
||||
labels = [-100] * len(input_ids)
|
||||
|
||||
# Find assistant turn boundaries in the raw text
|
||||
# We'll use a simpler approach: decode chunks and find assistant markers
|
||||
ASSISTANT_START = "<|im_start|>assistant\n"
|
||||
IM_END = "<|im_end|>"
|
||||
|
||||
# Find all assistant spans in the tokenized text by decoding ranges
|
||||
text_for_search = text
|
||||
pos = 0
|
||||
while True:
|
||||
start_idx = text_for_search.find(ASSISTANT_START, pos)
|
||||
if start_idx == -1:
|
||||
break
|
||||
end_idx = text_for_search.find(IM_END, start_idx + len(ASSISTANT_START))
|
||||
if end_idx == -1:
|
||||
end_idx = len(text_for_search)
|
||||
|
||||
# Map character offsets to token offsets
|
||||
# Approximate: count characters up to start/end, find token boundaries
|
||||
char_to_start = start_idx + len(ASSISTANT_START) # skip the marker itself
|
||||
char_to_end = end_idx + len(IM_END)
|
||||
|
||||
# Use tokenizer offset mapping if available
|
||||
enc_with_offsets = tokenizer(text, truncation=True, max_length=4096, return_offsets_mapping=True)
|
||||
offsets = enc_with_offsets.get("offset_mapping", None)
|
||||
|
||||
if offsets:
|
||||
tok_start = None
|
||||
tok_end = None
|
||||
for ti, (cs, ce) in enumerate(offsets):
|
||||
if cs >= char_to_start and tok_start is None:
|
||||
tok_start = ti
|
||||
if ce >= char_to_end:
|
||||
tok_end = ti + 1
|
||||
break
|
||||
if tok_start is not None and tok_end is not None:
|
||||
for i in range(tok_start, min(tok_end, len(labels))):
|
||||
labels[i] = input_ids[i]
|
||||
|
||||
pos = end_idx + 1
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--output-dir", type=str, default="/data/processed")
|
||||
parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)")
|
||||
parser.add_argument("--tokenize", action="store_true", help="Also produce tokenized versions")
|
||||
parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load all datasets
|
||||
all_samples = []
|
||||
all_samples.extend(load_multiturn_dataset())
|
||||
all_samples.extend(load_hermes_fc_dataset())
|
||||
all_samples.extend(load_xlam_dataset())
|
||||
|
||||
print(f"\nTotal raw samples: {len(all_samples)}")
|
||||
|
||||
# Shuffle & split
|
||||
random.seed(SEED)
|
||||
random.shuffle(all_samples)
|
||||
|
||||
if args.max_samples > 0:
|
||||
all_samples = all_samples[:args.max_samples]
|
||||
|
||||
val_count = max(1, int(len(all_samples) * VAL_FRACTION))
|
||||
val_samples = all_samples[:val_count]
|
||||
train_samples = all_samples[val_count:]
|
||||
|
||||
print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")
|
||||
|
||||
# Write raw JSONL
|
||||
for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
|
||||
path = output_dir / f"{split_name}.jsonl"
|
||||
with open(path, "w") as f:
|
||||
for s in split_data:
|
||||
f.write(json.dumps(s, ensure_ascii=False) + "\n")
|
||||
print(f"Wrote {path}")
|
||||
|
||||
# Optionally tokenize
|
||||
if args.tokenize:
|
||||
print(f"\nTokenizing with {args.model} ...")
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
|
||||
for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
|
||||
tok_path = output_dir / f"{split_name}_tokenized.jsonl"
|
||||
count = 0
|
||||
with open(tok_path, "w") as f:
|
||||
for s in split_data:
|
||||
tok = tokenize_sample(s, tokenizer)
|
||||
if tok:
|
||||
f.write(json.dumps(tok) + "\n")
|
||||
count += 1
|
||||
print(f"Wrote {tok_path} ({count} samples)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user