init commit

This commit is contained in:
Jinx
2026-04-10 06:24:05 +00:00
parent 46a3ddbb25
commit adbd85366b
3 changed files with 101 additions and 294 deletions

View File

@@ -2,18 +2,19 @@ FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
# System deps # System deps
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
git ninja-build packaging wget curl \ git ninja-build wget curl \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Python deps # Python deps
COPY requirements.txt /tmp/requirements.txt COPY requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt RUN pip install --no-cache-dir -r /tmp/requirements.txt packaging
# Copy scripts # Copy scripts
WORKDIR /app WORKDIR /app
COPY prepare_data.py /app/ COPY prepare_data.py /app/
COPY train_lora.py /app/ COPY train_lora.py /app/
COPY run.sh /app/ COPY run.sh /app/
RUN chmod +x /app/run.sh
# Data and output dirs # Data and output dirs
RUN mkdir -p /data/processed /data/lora-output /data/models RUN mkdir -p /data/processed /data/lora-output /data/models

View File

@@ -2,311 +2,147 @@
""" """
Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning. Prepare tool-calling training data for SmolLM3-3B LoRA fine-tuning.
Combines three datasets: Datasets:
1. interstellarninja/tool-calls-multiturn 1. interstellarninja/tool-calls-multiturn
2. NousResearch/Hermes-Function-Calling-V1 2. NousResearch/Hermes-Function-Calling-V1
3. Salesforce/xLAM-function-calling-60k
Converts all to SmolLM3's native chat format with proper special tokens: Both use ShareGPT format (from/value) with inline tool call tags.
- Tool calls wrapped in startPos/endPos tokens (IDs 128002/128016) We convert to SmolLM3's native token format.
- Tool responses wrapped in eni/eni_result tokens (IDs 128013/128014)
- Thinking wrapped in think_start/think_end tags
Output: train.jsonl, val.jsonl (tokenized & raw) Output: train.jsonl, val.jsonl
""" """
import json import json
import random import random
import re
from pathlib import Path from pathlib import Path
from datasets import load_dataset from datasets import load_dataset
# SmolLM3 special tokens (match the fixed chat_template.jinja)
TOOL_CALL_START = "<|tool_call_start|>" # token 128002
TOOL_CALL_END = "<|tool_call_end|>" # token 128016
TOOL_RESP_START = "<|tool_response_start|>" # token 128013
TOOL_RESP_END = "<|tool_response_end|>" # token 128014
THINK_START = "<think>"
THINK_END = "</think>"
VAL_FRACTION = 0.05 VAL_FRACTION = 0.05
SEED = 42 SEED = 42
# Hermes-style tags (used in the source datasets)
TC_OPEN = chr(60) + "tool" + chr(62) # <tool>
TC_CLOSE = chr(60) + "/tool" + chr(62) # </tool>
TR_OPEN = chr(60) + "tool_response" + chr(62) # <tool_response>
TR_CLOSE = chr(60) + "/tool_response" + chr(62) # </tool_response>
def render_tool_calls(tool_calls: list[dict]) -> str: # SmolLM3 native tokens
"""Render tool_calls list into SmolLM3's native format.""" SMOL_TC_START = "<|tool_call_start|>"
parts = [] SMOL_TC_END = "<|tool_call_end|>"
for tc in tool_calls: SMOL_TR_START = "<|tool_response_start|>"
name = tc["function"]["name"] SMOL_TR_END = "<|tool_response_end|>"
args = tc["function"]["arguments"]
if isinstance(args, str):
args_str = args
else:
args_str = json.dumps(args, ensure_ascii=False)
parts.append(f'{{"name": "{name}", "arguments": {args_str}}}')
body = "\n".join(parts)
return f"{TOOL_CALL_START}\n{body}\n{TOOL_CALL_END}"
def render_tool_response(content: str) -> str: def convert_sharegpt_to_smollm3(conversations, tools_json=None):
"""Wrap tool response content in SmolLM3's tool_response tokens.""" """Convert ShareGPT-style conversation to SmolLM3 messages."""
return f"{TOOL_RESP_START}\n{content}\n{TOOL_RESP_END}" messages = []
if tools_json:
try:
tools_list = json.loads(tools_json) if isinstance(tools_json, str) else tools_json
except json.JSONDecodeError:
tools_list = None
def convert_openai_messages(messages: list[dict], tools: list[dict] | None = None) -> list[dict]: if tools_list:
"""Convert standard OpenAI-format messages to SmolLM3 native format. tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools_list)
Transforms:
- assistant.tool_calls → content with startPos/endPos tokens
- tool role messages → user role with eni/eni_result tokens
- Adds system prompt with tool definitions if tools present
"""
converted = []
# Build system message with tool defs if present
if tools:
tool_defs = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools)
system_content = ( system_content = (
"You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n"
"### Tools\n\n" "### Tools\n\n"
"You may call one or more functions to assist with the user query.\n" "You may call one or more functions to assist with the user query.\n"
"You are provided with function signatures within <tools></tools> XML tags:\n\n" "You are provided with function signatures within <tools></tools> XML tags:\n\n"
f"<tools>\n{tool_defs}\n</tools>\n\n" f"<tools>\n{tool_defs}\n</tools>\n\n"
'For each function call, return a json object with function name and arguments within ' "For each function call, return a json object with function name and arguments within "
f'{TOOL_CALL_START} {TOOL_CALL_END} tags:\n' f"special tags:\n{SMOL_TC_START}\n"
f'{TOOL_CALL_START}\n{{"name": <function-name>, "arguments": <args-json-object>}}\n{TOOL_CALL_END}' '{{"name": <function-name>, "arguments": <args-json-object>}}\n'
f"{SMOL_TC_END}\n"
) )
converted.append({"role": "system", "content": system_content})
elif messages and messages[0].get("role") == "system":
converted.append({"role": "system", "content": messages[0]["content"]})
messages = messages[1:]
else: else:
converted.append({ system_content = "You are a helpful AI assistant named SmolLM, trained by Hugging Face."
"role": "system",
"content": "You are a helpful AI assistant named SmolLM, trained by Hugging Face."
})
for msg in messages:
role = msg.get("role", "user")
if role == "user":
converted.append({"role": "user", "content": msg["content"]})
elif role == "assistant":
content = msg.get("content") or ""
tool_calls = msg.get("tool_calls")
if tool_calls:
tc_text = render_tool_calls(tool_calls)
full_content = f"{content}\n{tc_text}" if content else tc_text
converted.append({"role": "assistant", "content": full_content})
else: else:
converted.append({"role": "assistant", "content": content}) system_content = "You are a helpful AI assistant named SmolLM, trained by Hugging Face."
messages.append({"role": "system", "content": system_content})
for turn in conversations:
role = turn.get("from", turn.get("role", ""))
value = turn.get("value", turn.get("content", ""))
if role == "system":
continue
elif role in ("human", "user"):
messages.append({"role": "user", "content": value})
elif role in ("assistant", "gpt"):
content = value.replace(TC_OPEN, SMOL_TC_START)
content = content.replace(TC_CLOSE, SMOL_TC_END)
if content.strip():
messages.append({"role": "assistant", "content": content})
elif role == "tool": elif role == "tool":
# Tool responses become user messages with eni/eni_result tokens content = value.replace(TR_OPEN, SMOL_TR_START)
content = msg.get("content", "") content = content.replace(TR_CLOSE, SMOL_TR_END)
if isinstance(content, list): messages.append({"role": "user", "content": content})
content = " ".join(c.get("text", "") for c in content if isinstance(c, dict))
converted.append({
"role": "user",
"content": render_tool_response(str(content))
})
return converted has_tool_call = any(
SMOL_TC_START in m.get("content", "")
for m in messages
if m["role"] == "assistant"
)
if not has_tool_call:
return None
return messages
def load_multiturn_dataset() -> list[dict]: def load_multiturn_dataset():
"""Load interstellarninja/tool-calls-multiturn."""
print("Loading interstellarninja/tool-calls-multiturn ...") print("Loading interstellarninja/tool-calls-multiturn ...")
ds = load_dataset("interstellarninja/tool-calls-multiturn", split="train") ds = load_dataset("interstellarninja/tool-calls-multiturn", split="train")
samples = [] samples = []
for row in ds: for row in ds:
messages = row.get("messages", []) conversations = row.get("conversations", [])
tools = row.get("tools") tools = row.get("tools")
if not messages or not any(m.get("tool_calls") for m in messages if m.get("role") == "assistant"): if not conversations:
continue # skip conversations with no tool calls continue
converted = convert_openai_messages(messages, tools) tools_str = tools if isinstance(tools, str) else (json.dumps(tools) if tools else None)
converted = convert_sharegpt_to_smollm3(conversations, tools_str)
if converted:
samples.append({"messages": converted}) samples.append({"messages": converted})
print(f" {len(samples)} samples with tool calls") print(f" -> {len(samples)} samples with tool calls")
return samples return samples
def load_hermes_fc_dataset() -> list[dict]: def load_hermes_fc_dataset():
"""Load NousResearch/Hermes-Function-Calling-V1."""
print("Loading NousResearch/Hermes-Function-Calling-V1 ...") print("Loading NousResearch/Hermes-Function-Calling-V1 ...")
ds = load_dataset("NousResearch/Hermes-Function-Calling-V1", split="train") ds = load_dataset("NousResearch/Hermes-Function-Calling-V1", split="train")
samples = [] samples = []
for row in ds: for row in ds:
messages = row.get("messages", []) conversations = row.get("conversations", [])
tools = row.get("tools") tools = row.get("tools")
if not messages or not any(m.get("tool_calls") for m in messages if m.get("role") == "assistant"): if not conversations:
continue continue
converted = convert_openai_messages(messages, tools) tools_str = tools if isinstance(tools, str) else (json.dumps(tools) if tools else None)
converted = convert_sharegpt_to_smollm3(conversations, tools_str)
if converted:
samples.append({"messages": converted}) samples.append({"messages": converted})
print(f" {len(samples)} samples with tool calls") print(f" -> {len(samples)} samples with tool calls")
return samples return samples
def load_xlam_dataset() -> list[dict]:
"""Load Salesforce/xLAM-function-calling-60k.
This dataset uses a different format: each row has 'tools', 'instruction', and 'outputs'.
We convert to conversation format.
"""
print("Loading Salesforce/xLAM-function-calling-60k ...")
ds = load_dataset("Salesforce/xLAM-function-calling-60k", split="train")
samples = []
for row in ds:
tools_raw = row.get("tools", "[]")
instruction = row.get("instruction", "")
outputs = row.get("answers", row.get("outputs", ""))
if not instruction or not outputs:
continue
try:
tools_list = json.loads(tools_raw) if isinstance(tools_raw, str) else tools_raw
except json.JSONDecodeError:
continue
if not tools_list:
continue
# Parse the model output — may contain one or more tool calls
try:
output_parsed = json.loads(outputs) if isinstance(outputs, str) else outputs
except json.JSONDecodeError:
continue
# Build messages
messages = [{"role": "user", "content": instruction}]
if isinstance(output_parsed, list):
# Multiple tool calls
tool_calls = []
for item in output_parsed:
if isinstance(item, dict) and "name" in item:
tool_calls.append({
"function": {
"name": item["name"],
"arguments": item.get("arguments", item.get("parameters", {}))
}
})
if tool_calls:
messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""})
elif isinstance(output_parsed, dict) and "name" in output_parsed:
messages.append({
"role": "assistant",
"tool_calls": [{
"function": {
"name": output_parsed["name"],
"arguments": output_parsed.get("arguments", output_parsed.get("parameters", {}))
}
}],
"content": ""
})
else:
continue
converted = convert_openai_messages(messages, tools_list)
samples.append({"messages": converted})
print(f"{len(samples)} samples with tool calls")
return samples
def tokenize_sample(sample: dict, tokenizer) -> dict | None:
"""Tokenize a sample using the model's chat template.
Returns dict with input_ids, attention_mask, labels (with system/user masked to -100).
"""
messages = sample["messages"]
try:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
enc = tokenizer(text, truncation=True, max_length=4096)
except Exception as e:
print(f" ⚠ Tokenization failed: {e}")
return None
input_ids = enc["input_ids"]
attention_mask = enc["attention_mask"]
# Build labels: mask system + user tokens, only train on assistant responses
labels = [-100] * len(input_ids)
# Find assistant turn boundaries in the raw text
# We'll use a simpler approach: decode chunks and find assistant markers
ASSISTANT_START = "<|im_start|>assistant\n"
IM_END = "<|im_end|>"
# Find all assistant spans in the tokenized text by decoding ranges
text_for_search = text
pos = 0
while True:
start_idx = text_for_search.find(ASSISTANT_START, pos)
if start_idx == -1:
break
end_idx = text_for_search.find(IM_END, start_idx + len(ASSISTANT_START))
if end_idx == -1:
end_idx = len(text_for_search)
# Map character offsets to token offsets
# Approximate: count characters up to start/end, find token boundaries
char_to_start = start_idx + len(ASSISTANT_START) # skip the marker itself
char_to_end = end_idx + len(IM_END)
# Use tokenizer offset mapping if available
enc_with_offsets = tokenizer(text, truncation=True, max_length=4096, return_offsets_mapping=True)
offsets = enc_with_offsets.get("offset_mapping", None)
if offsets:
tok_start = None
tok_end = None
for ti, (cs, ce) in enumerate(offsets):
if cs >= char_to_start and tok_start is None:
tok_start = ti
if ce >= char_to_end:
tok_end = ti + 1
break
if tok_start is not None and tok_end is not None:
for i in range(tok_start, min(tok_end, len(labels))):
labels[i] = input_ids[i]
pos = end_idx + 1
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
def main(): def main():
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--output-dir", type=str, default="/data/processed") parser.add_argument("--output-dir", type=str, default="/data")
parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)") parser.add_argument("--max-samples", type=int, default=0, help="Limit total samples (0=all)")
parser.add_argument("--tokenize", action="store_true", help="Also produce tokenized versions")
parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
args = parser.parse_args() args = parser.parse_args()
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Load all datasets
all_samples = [] all_samples = []
all_samples.extend(load_multiturn_dataset()) all_samples.extend(load_multiturn_dataset())
all_samples.extend(load_hermes_fc_dataset()) all_samples.extend(load_hermes_fc_dataset())
all_samples.extend(load_xlam_dataset())
print(f"\nTotal raw samples: {len(all_samples)}") print(f"\nTotal raw samples: {len(all_samples)}")
# Shuffle & split
random.seed(SEED) random.seed(SEED)
random.shuffle(all_samples) random.shuffle(all_samples)
@@ -319,7 +155,6 @@ def main():
print(f"Train: {len(train_samples)}, Val: {len(val_samples)}") print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")
# Write raw JSONL
for split_name, split_data in [("train", train_samples), ("val", val_samples)]: for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
path = output_dir / f"{split_name}.jsonl" path = output_dir / f"{split_name}.jsonl"
with open(path, "w") as f: with open(path, "w") as f:
@@ -327,22 +162,7 @@ def main():
f.write(json.dumps(s, ensure_ascii=False) + "\n") f.write(json.dumps(s, ensure_ascii=False) + "\n")
print(f"Wrote {path}") print(f"Wrote {path}")
# Optionally tokenize print("Data preparation complete!")
if args.tokenize:
print(f"\nTokenizing with {args.model} ...")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model)
for split_name, split_data in [("train", train_samples), ("val", val_samples)]:
tok_path = output_dir / f"{split_name}_tokenized.jsonl"
count = 0
with open(tok_path, "w") as f:
for s in split_data:
tok = tokenize_sample(s, tokenizer)
if tok:
f.write(json.dumps(tok) + "\n")
count += 1
print(f"Wrote {tok_path} ({count} samples)")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -2,13 +2,13 @@
""" """
LoRA fine-tuning script for SmolLM3-3B tool-calling. LoRA fine-tuning script for SmolLM3-3B tool-calling.
Uses PEFT + transformers + accelerate. Designed to run inside the Docker container. Uses PEFT + transformers + accelerate. Runs inside the Docker container.
Usage: Usage:
python train_lora.py \ python train_lora.py \
--data-dir /data/processed \ --data-dir /data \
--model HuggingFaceTB/SmolLM3-3B \ --model HuggingFaceTB/SmolLM3-3B \
--output-dir /data/lora-output \ --output-dir /output \
--epochs 3 \ --epochs 3 \
--batch-size 4 \ --batch-size 4 \
--lr 2e-4 --lr 2e-4
@@ -30,7 +30,7 @@ from transformers import (
) )
def load_jsonl(path: Path) -> list[dict]: def load_jsonl(path):
samples = [] samples = []
with open(path) as f: with open(path) as f:
for line in f: for line in f:
@@ -40,14 +40,12 @@ def load_jsonl(path: Path) -> list[dict]:
return samples return samples
def tokenize_for_training(sample: dict, tokenizer, max_length: int = 4096) -> dict: def tokenize_for_training(sample, tokenizer, max_length=4096):
"""Tokenize a chat-formatted sample and build labels. """Tokenize a chat-formatted sample and build labels.
Masks everything except assistant responses (labels = -100 for non-assistant tokens). Masks everything except assistant responses (labels = -100 for non-assistant tokens).
""" """
messages = sample["messages"] messages = sample["messages"]
# Build the full text using the tokenizer's chat template
text = tokenizer.apply_chat_template( text = tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
@@ -65,35 +63,34 @@ def tokenize_for_training(sample: dict, tokenizer, max_length: int = 4096) -> di
attention_mask = enc["attention_mask"] attention_mask = enc["attention_mask"]
labels = [-100] * len(input_ids) labels = [-100] * len(input_ids)
# Find assistant turn boundaries ASSISTANT_MARKER = "<|im_start|>assistant"
ASSISTANT_MARKER = "<|im_start|>assistant\n"
END_MARKER = "<|im_end|>" END_MARKER = "<|im_end|>"
offsets = enc.get("offset_mapping", []) offsets = enc.get("offset_mapping", [])
if not offsets: if not offsets:
# Fallback: just train on everything after first assistant turn
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"labels": labels, "labels": labels,
} }
# Find assistant spans in the raw text
pos = 0 pos = 0
while True: while True:
start_idx = text.find(ASSISTANT_MARKER, pos) start_idx = text.find(ASSISTANT_MARKER, pos)
if start_idx == -1: if start_idx == -1:
break break
# Content starts after the marker # Content starts after the marker + newline
content_start = start_idx + len(ASSISTANT_MARKER) content_start = start_idx + len(ASSISTANT_MARKER)
if content_start < len(text) and text[content_start] == "\n":
content_start += 1
end_idx = text.find(END_MARKER, content_start) end_idx = text.find(END_MARKER, content_start)
if end_idx == -1: if end_idx == -1:
span_end = len(text) span_end = len(text)
else: else:
span_end = end_idx + len(END_MARKER) span_end = end_idx + len(END_MARKER)
# Map character offsets to token indices
tok_start = None tok_start = None
tok_end = None tok_end = None
for ti, (cs, ce) in enumerate(offsets): for ti, (cs, ce) in enumerate(offsets):
@@ -118,9 +115,9 @@ def tokenize_for_training(sample: dict, tokenizer, max_length: int = 4096) -> di
def main(): def main():
parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling") parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B for tool calling")
parser.add_argument("--data-dir", type=str, default="/data/processed") parser.add_argument("--data-dir", type=str, default="/data")
parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B") parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B")
parser.add_argument("--output-dir", type=str, default="/data/lora-output") parser.add_argument("--output-dir", type=str, default="/output")
parser.add_argument("--epochs", type=int, default=3) parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch-size", type=int, default=4) parser.add_argument("--batch-size", type=int, default=4)
parser.add_argument("--grad-accum", type=int, default=4) parser.add_argument("--grad-accum", type=int, default=4)
@@ -136,13 +133,11 @@ def main():
parser.add_argument("--resume-from", type=str, default=None) parser.add_argument("--resume-from", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
# Load tokenizer
print(f"Loading tokenizer: {args.model}") print(f"Loading tokenizer: {args.model}")
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
if tokenizer.pad_token is None: if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
# Load model
print(f"Loading model: {args.model}") print(f"Loading model: {args.model}")
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
args.model, args.model,
@@ -151,7 +146,6 @@ def main():
device_map="auto", device_map="auto",
) )
# Configure LoRA
lora_config = LoraConfig( lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
r=args.lora_r, r=args.lora_r,
@@ -167,13 +161,11 @@ def main():
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)
model.print_trainable_parameters() model.print_trainable_parameters()
# Load data
data_dir = Path(args.data_dir) data_dir = Path(args.data_dir)
train_data = load_jsonl(data_dir / "train.jsonl") train_data = load_jsonl(data_dir / "train.jsonl")
val_data = load_jsonl(data_dir / "val.jsonl") val_data = load_jsonl(data_dir / "val.jsonl")
print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}") print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")
# Tokenize
print("Tokenizing training data ...") print("Tokenizing training data ...")
train_dataset = Dataset.from_list(train_data).map( train_dataset = Dataset.from_list(train_data).map(
lambda x: tokenize_for_training(x, tokenizer, args.max_length), lambda x: tokenize_for_training(x, tokenizer, args.max_length),
@@ -186,14 +178,12 @@ def main():
desc="Tokenizing val", desc="Tokenizing val",
) )
# Data collator
data_collator = DataCollatorForSeq2Seq( data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer, tokenizer=tokenizer,
padding=True, padding=True,
return_tensors="pt", return_tensors="pt",
) )
# Training arguments
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir=args.output_dir, output_dir=args.output_dir,
num_train_epochs=args.epochs, num_train_epochs=args.epochs,
@@ -223,7 +213,6 @@ def main():
dataloader_pin_memory=True, dataloader_pin_memory=True,
) )
# Trainer
trainer = Trainer( trainer = Trainer(
model=model, model=model,
args=training_args, args=training_args,
@@ -232,17 +221,14 @@ def main():
data_collator=data_collator, data_collator=data_collator,
) )
# Train
print("Starting training ...") print("Starting training ...")
trainer.train(resume_from_checkpoint=args.resume_from) trainer.train(resume_from_checkpoint=args.resume_from)
# Save final adapter
print(f"Saving LoRA adapter to {args.output_dir}/final") print(f"Saving LoRA adapter to {args.output_dir}/final")
model.save_pretrained(f"{args.output_dir}/final") model.save_pretrained(f"{args.output_dir}/final")
tokenizer.save_pretrained(f"{args.output_dir}/final") tokenizer.save_pretrained(f"{args.output_dir}/final")
# Also save the tokenizer chat template for deployment print("Done!")
print("Done! 🎭")
if __name__ == "__main__": if __name__ == "__main__":