Add training plan: teach model to emit native tool-call tokens

This commit is contained in:
Jinx
2026-04-10 17:07:28 +00:00
parent d1e8c306e3
commit af497eb16c

178
TRAINING_PLAN.md Normal file
View File

@@ -0,0 +1,178 @@
# LoRA Training Plan — Teaching SmolLM3-3B to Emit Tool-Call Tokens
## The Problem
SmolLM3-3B does not emit native tool-call tokens. When asked to use tools, it writes
Python code that *calls* the tool as a function instead of emitting the structured
token sequences (IDs 128015/128016) that vLLM's Hermes parser expects.
This was verified with a raw token inspector that bypasses all middleware:
| Token ID | Decoded | Purpose |
|----------|---------|---------|
| 128015 | ` startPos` | Tool call start delimiter |
| 128016 | ` endPos` | Tool call end delimiter |
| 128013 | ` eni` | Tool response start delimiter |
| 128014 | ` eni_result` | Tool response end delimiter |
The model knows these tokens exist (they're in its vocabulary and the chat template
references them) but it was never trained to *produce* them. It falls back to the
only behavior it knows: writing code.
## Training Strategy
### Core Principle: Train on Raw Tokens, Not Parsed Output
The training data must contain the **actual token IDs** 128015/128016 in the assistant
response. We cannot use any parser or template that "corrects" the output — that would
mask the problem instead of fixing it.
The `apply_chat_template()` call in `train_lora.py` handles this correctly: when the
messages contain an assistant turn with a tool call, the template renders it using the
special delimiters. The loss mask ensures only the assistant's tokens are trained on.
So the model learns: "when tools are available and the user asks me to use one, I emit
the start token, then JSON, then the end token."
### What the Data Must Look Like
Every training sample must have the tool-call delimiters wrapping JSON in the
assistant turn. The key is that after tokenization, token IDs 128015 and 128016
appear in the `input_ids` array within the assistant's labeled region.
**What we need** (tokenized input_ids contains [128015, ...json..., 128016]):
- Assistant turn with: start delimiter + JSON tool call + end delimiter
**What the base model currently produces** (prose/code tokens, no special tokens):
- Assistant turn with: prose explaining the tool + Python code calling it
The training data must have the first pattern. The loss function only trains on
assistant turns, so the model will learn to emit the special tokens.
### Dataset Choice
**Primary: NousResearch/Hermes-Function-Calling-V1**
Why this one:
1. **Already has tool_call tags in the text** — the assistant responses contain
the standard Hermes tool call format
2. **Large and diverse** — 20k+ tool-calling conversations covering many function types
3. **Multi-turn** — includes tool responses and follow-up turns, so the model also
learns to read the response delimiters and respond to tool results
4. **Clean format** — ShareGPT, easy to convert
The conversion in `prepare_data.py` already does the right thing: it replaces
the Hermes tags with SmolLM3's native token strings. After conversion, the
training data has the actual token IDs in the text. When `apply_chat_template()`
tokenizes this, token 128015 and 128016 end up in the `input_ids`, and since they're
in the assistant turn, the loss mask includes them.
**Drop: interstellarninja/tool-calls-multiturn** for now — it's noisier and has more
formatting inconsistency. We can add it later if we need more volume, but starting
clean is better for this focused training run.
**Skip: Salesforce/xLAM-function-calling-60k** — too large, lots of low-quality
samples. We want quality over quantity for a LoRA.
### Data Mix: Don't Forget Non-Tool Turns
If we train on *only* tool-calling samples, the model may learn to always emit
the start delimiter even when it shouldn't. We need a mix:
- **70% tool-calling samples** (Hermes V1, filtered to only tool-call conversations)
- **30% general instruction-following samples** (from SmolLM3's original training data
or a clean instruct dataset like `HuggingFaceTB/smollm-corpus`)
This teaches the model: "emit the tool-call tokens when tools are available AND the
user's request needs a tool call; respond with normal text otherwise."
### Critical Fix: Verify Token IDs in Training Data
Before training, add a verification step that confirms the processed data actually
contains token IDs 128015 and 128016. If `apply_chat_template()` doesn't render
them correctly, the model won't learn them.
```python
# Verification: tokenize a sample and check for tool-call token IDs
sample = train_data[0]
text = tokenizer.apply_chat_template(sample["messages"], tokenize=False)
ids = tokenizer.encode(text)
assert 128015 in ids, "Tool call start token (128015) not found in training data!"
assert 128016 in ids, "Tool call end token (128016) not found in training data!"
```
If this fails, the `prepare_data.py` conversion isn't working or the tokenizer isn't
recognizing the special tokens. Fix before training.
### Training Parameters
The current defaults are reasonable. Key considerations:
| Param | Value | Rationale |
|-------|-------|-----------|
| lora_r | 16 | Standard for 3B model. Could go to 32 if loss plateaus. |
| lora_alpha | 32 | 2x rank, standard. |
| lr | 2e-4 | Good for LoRA. If loss spikes, drop to 1e-4. |
| epochs | 3 | Start here. Check val loss — if still dropping at epoch 3, go to 5. |
| max_length | 4096 | Enough for tool calls + code content. |
| target_modules | q/k/v/o + gate/up/down + embed_tokens | Full coverage. embed_tokens critical — see below. |
### The embed_tokens Question
Token IDs 128015/128016 are in the vocabulary but the model has never used them in
context. Two options:
1. **Include `embed_tokens` in LoRA target modules** — lets the adapter adjust the
embeddings for these tokens so the model can route to them during generation.
**Recommended for this run.** Add `"embed_tokens"` to `target_modules`.
2. **Don't include it** — the base embeddings are frozen, the model can still produce
these tokens (they're in its vocab) but may not have strong enough signal to route
to them. Risky.
Add to `target_modules` in `train_lora.py`:
```python
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
"embed_tokens", # Critical: lets LoRA adjust tool-call token embeddings
],
```
Note: PEFT handles `embed_tokens` with LoRA correctly — it applies the low-rank
adaptation to the embedding matrix without issues.
### Validation: Post-Training Token Emission Test
After training, before deploying, run the trained model through the
`chat-template-debugger` (stage 1) to verify the model now emits 128015/128016:
1. Merge the LoRA adapter into the base model
2. Copy merged model to the GPU server's `chat-template-debugger/models/`
3. Run `stage1_debug.py` with the write_file and save_config prompts
4. **Pass criteria:** token IDs 128015 and 128016 appear in the output, followed by
valid JSON, followed by 128016. No Python code-dumping.
### Failure Modes & Mitigations
| Failure | Symptom | Fix |
|---------|---------|-----|
| Model still code-dumps | No 128015/128016 in output | Check verification step; increase epochs; add embed_tokens to targets |
| Model emits tokens but broken JSON | 128015 present but invalid JSON follows | Add more diverse tool-call samples; increase max_length |
| Model over-fits to tool calls | Emits start delimiter for non-tool queries | Add 30% non-tool instruction data |
| Loss doesn't decrease | Val loss flat or increasing | Check label masking (not all -100); verify data quality; lower LR |
| LoRA can't adjust embeddings | embed_tokens not in target | PEFT supports this — make sure the module name matches exactly |
## Summary
The training data is the key. The existing `prepare_data.py` already converts
Hermes-format tool calls to SmolLM3's native tokens. The chat template renders
them as the actual special token IDs. The loss mask trains only on assistant turns.
So the model will learn: "when I see tools available and the user wants me to use one,
I emit token 128015, then the JSON function call, then token 128016."
The two critical changes from the current setup:
1. **Add `embed_tokens` to LoRA target modules** so the adapter can shape the
embeddings for the tool-call tokens
2. **Add 30% non-tool instruction data** to prevent overfitting to tool calls
After training, validate with the raw token debugger before deploying to vLLM.