# LoRA Training Plan — Teaching SmolLM3-3B to Emit Tool-Call Tokens

## The Problem

SmolLM3-3B does not emit native tool-call tokens. When asked to use tools, it writes Python code that *calls* the tool as a function instead of emitting the structured token sequences (IDs 128015/128016) that vLLM's Hermes parser expects.

This was verified with a raw token inspector that bypasses all middleware:
| Token ID | Decoded | Purpose |
|----------|---------|---------|
| 128015 | ` startPos` | Tool call start delimiter |
| 128016 | ` endPos` | Tool call end delimiter |
| 128013 | ` eni` | Tool response start delimiter |
| 128014 | ` eni_result` | Tool response end delimiter |
The model knows these tokens exist (they're in its vocabulary and the chat template references them), but it was never trained to *produce* them. It falls back to the only behavior it knows: writing code.
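The core pass/fail check the inspector performs can be sketched as a pure function over the raw generated ids (a minimal sketch; the function name is illustrative, not the actual debugger code):

```python
# Hypothetical sketch of the raw-token check: scan generated ids for a
# well-formed tool-call span (128015 ... 128016), bypassing any parser.
TOOL_CALL_START = 128015
TOOL_CALL_END = 128016

def has_native_tool_call(output_ids: list[int]) -> bool:
    """True if a start delimiter appears and an end delimiter follows it."""
    if TOOL_CALL_START not in output_ids:
        return False
    start = output_ids.index(TOOL_CALL_START)
    return TOOL_CALL_END in output_ids[start + 1:]
```

Run against the base model's output, this check fails: the generated ids are ordinary prose and code tokens, never the special delimiters.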
## Training Strategy

### Core Principle: Train on Raw Tokens, Not Parsed Output
The training data must contain the **actual token IDs** 128015/128016 in the assistant response. We cannot use any parser or template that "corrects" the output — that would mask the problem instead of fixing it.
The `apply_chat_template()` call in `train_lora.py` handles this correctly: when the messages contain an assistant turn with a tool call, the template renders it using the special delimiters. The loss mask ensures only the assistant's tokens are trained on. So the model learns: "when tools are available and the user asks me to use one, I emit the start token, then JSON, then the end token."
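The masking step can be sketched in isolation (a simplified stand-in for what the training script does; the real code derives the assistant span from the template, here it is passed in directly):

```python
IGNORE_INDEX = -100  # HF convention: label positions set to -100 are excluded from the loss

def mask_labels(input_ids: list[int], assistant_start: int, assistant_end: int) -> list[int]:
    """Copy input_ids as labels, but keep only the assistant span [start, end)."""
    return [
        tok if assistant_start <= i < assistant_end else IGNORE_INDEX
        for i, tok in enumerate(input_ids)
    ]

# Only the assistant's tokens (including 128015/128016) carry loss:
ids = [10, 11, 128015, 42, 128016, 12]
labels = mask_labels(ids, assistant_start=2, assistant_end=5)
# labels == [-100, -100, 128015, 42, 128016, -100]
```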
### What the Data Must Look Like
Every training sample must have the tool-call delimiters wrapping JSON in the assistant turn. The key is that after tokenization, token IDs 128015 and 128016 appear in the `input_ids` array within the assistant's labeled region.

**What we need** (tokenized `input_ids` contains `[128015, ...json..., 128016]`):
- Assistant turn with: start delimiter + JSON tool call + end delimiter

**What the base model currently produces** (prose/code tokens, no special tokens):
- Assistant turn with: prose explaining the tool + Python code calling it

The training data must have the first pattern. The loss function only trains on assistant turns, so the model will learn to emit the special tokens.
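Assuming the pipeline uses OpenAI-style messages (an illustrative schema — the exact field names depend on `prepare_data.py`), a single sample of the first pattern might look like:

```python
import json

# Hypothetical training sample: the assistant turn carries a structured tool
# call, which the chat template should render wrapped in 128015/128016.
sample = {
    "messages": [
        {"role": "user", "content": "Save the config to disk."},
        {
            "role": "assistant",
            "tool_calls": [{
                "function": {
                    "name": "save_config",  # illustrative tool name
                    "arguments": json.dumps({"path": "config.json"}),
                }
            }],
        },
    ]
}
```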
### Dataset Choice
**Primary: NousResearch/Hermes-Function-Calling-V1**

Why this one:
1. **Already has tool_call tags in the text** — the assistant responses contain the standard Hermes tool call format
2. **Large and diverse** — 20k+ tool-calling conversations covering many function types
3. **Multi-turn** — includes tool responses and follow-up turns, so the model also learns to read the response delimiters and respond to tool results
4. **Clean format** — ShareGPT, easy to convert
The conversion in `prepare_data.py` already does the right thing: it replaces the Hermes tags with SmolLM3's native token strings. After conversion, the training data has the actual token IDs in the text. When `apply_chat_template()` tokenizes this, tokens 128015 and 128016 end up in the `input_ids`, and since they're in the assistant turn, the loss mask includes them.
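The core of that conversion can be sketched as a string substitution (simplified; `NATIVE_START`/`NATIVE_END` are placeholders — the real script would take the actual strings from `tokenizer.decode([128015])` and `tokenizer.decode([128016])`):

```python
# Placeholder native delimiter strings -- in the real script these come from
# tokenizer.decode([128015]) and tokenizer.decode([128016]).
NATIVE_START = "<NATIVE_TOOL_CALL_START>"
NATIVE_END = "<NATIVE_TOOL_CALL_END>"

def convert_hermes_tags(text: str) -> str:
    """Swap Hermes-style <tool_call> tags for the model's native delimiter strings."""
    return (
        text.replace("<tool_call>", NATIVE_START)
            .replace("</tool_call>", NATIVE_END)
    )

hermes = '<tool_call>{"name": "write_file"}</tool_call>'
converted = convert_hermes_tags(hermes)
# converted == '<NATIVE_TOOL_CALL_START>{"name": "write_file"}<NATIVE_TOOL_CALL_END>'
```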
**Drop: interstellarninja/tool-calls-multiturn** for now — it's noisier and has more formatting inconsistency. We can add it later if we need more volume, but starting clean is better for this focused training run.

**Skip: Salesforce/xLAM-function-calling-60k** — too large, lots of low-quality samples. We want quality over quantity for a LoRA.
### Data Mix: Don't Forget Non-Tool Turns

If we train on *only* tool-calling samples, the model may learn to always emit the start delimiter even when it shouldn't. We need a mix:

- **70% tool-calling samples** (Hermes V1, filtered to only tool-call conversations)
- **30% general instruction-following samples** (from SmolLM3's original training data or a clean instruct dataset like `HuggingFaceTB/smollm-corpus`)

This teaches the model: "emit the tool-call tokens when tools are available AND the user's request needs a tool call; respond with normal text otherwise."
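The mix can be sketched with a seeded draw over two pools (illustrative only; real code would operate on HF datasets rather than in-memory lists):

```python
import random

def build_mix(tool_samples, general_samples, tool_frac=0.7, total=1000, seed=0):
    """Draw a shuffled mix: ~70% tool-calling, ~30% general instruction data."""
    rng = random.Random(seed)
    n_tool = int(total * tool_frac)
    mixed = (
        rng.sample(tool_samples, n_tool)
        + rng.sample(general_samples, total - n_tool)
    )
    rng.shuffle(mixed)  # interleave so batches see both kinds of turns
    return mixed
```

Shuffling matters here: if all tool-call samples come first, late training steps see only non-tool data and can wash out the delimiter behavior.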
### Critical Fix: Verify Token IDs in Training Data

Before training, add a verification step that confirms the processed data actually contains token IDs 128015 and 128016. If `apply_chat_template()` doesn't render them correctly, the model won't learn them.
```python
# Verification: tokenize a sample and check for tool-call token IDs
sample = train_data[0]
text = tokenizer.apply_chat_template(sample["messages"], tokenize=False)
ids = tokenizer.encode(text)
assert 128015 in ids, "Tool call start token (128015) not found in training data!"
assert 128016 in ids, "Tool call end token (128016) not found in training data!"
assert ids.index(128015) < ids.index(128016), "Tool-call delimiters out of order!"
```
If this fails, the `prepare_data.py` conversion isn't working or the tokenizer isn't recognizing the special tokens. Fix before training.
### Training Parameters

The current defaults are reasonable. Key considerations:
| Param | Value | Rationale |
|-------|-------|-----------|
| lora_r | 16 | Standard for a 3B model. Could go to 32 if loss plateaus. |
| lora_alpha | 32 | 2x rank, standard. |
| lr | 2e-4 | Good for LoRA. If loss spikes, drop to 1e-4. |
| epochs | 3 | Start here. Check val loss — if still dropping at epoch 3, go to 5. |
| max_length | 4096 | Enough for tool calls + code content. |
| target_modules | q/k/v/o + gate/up/down + embed_tokens | Full coverage. embed_tokens is critical — see below. |
### The embed_tokens Question

Token IDs 128015/128016 are in the vocabulary, but the model has never used them in context. Two options:
1. **Include `embed_tokens` in LoRA target modules** — lets the adapter adjust the embeddings for these tokens so the model can route to them during generation. **Recommended for this run.** Add `"embed_tokens"` to `target_modules`.

2. **Don't include it** — the base embeddings stay frozen; the model can still produce these tokens (they're in its vocab) but may not have a strong enough signal to route to them. Risky.
Add to `target_modules` in `train_lora.py`:

```python
target_modules=[
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "embed_tokens",  # Critical: lets LoRA adjust tool-call token embeddings
],
```
Note: PEFT handles `embed_tokens` with LoRA correctly — it applies the low-rank adaptation to the embedding matrix without issues.
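Putting the parameter table and the embedding target together, the full adapter config might look like this (a sketch assuming PEFT's `LoraConfig`; the dropout value is an assumption, not from the table):

```python
from peft import LoraConfig

# Sketch: adapter config combining the parameter table with the embed_tokens
# target. Values mirror the table above; lora_dropout is an assumed default.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,  # assumption: not specified in the table
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens",  # lets the adapter adjust the tool-call token embeddings
    ],
    task_type="CAUSAL_LM",
)
```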
### Validation: Post-Training Token Emission Test

After training, before deploying, run the trained model through the `chat-template-debugger` (stage 1) to verify the model now emits 128015/128016:
1. Merge the LoRA adapter into the base model
2. Copy the merged model to the GPU server's `chat-template-debugger/models/`
3. Run `stage1_debug.py` with the write_file and save_config prompts
4. **Pass criteria:** token ID 128015 appears in the output, followed by valid JSON, followed by 128016. No Python code-dumping.
### Failure Modes & Mitigations

| Failure | Symptom | Fix |
|---------|---------|-----|
| Model still code-dumps | No 128015/128016 in output | Check verification step; increase epochs; add embed_tokens to targets |
| Model emits tokens but broken JSON | 128015 present but invalid JSON follows | Add more diverse tool-call samples; increase max_length |
| Model overfits to tool calls | Emits start delimiter for non-tool queries | Add 30% non-tool instruction data |
| Loss doesn't decrease | Val loss flat or increasing | Check label masking (not all -100); verify data quality; lower LR |
| LoRA can't adjust embeddings | embed_tokens not in target_modules | PEFT supports this — make sure the module name matches exactly |
## Summary

The training data is the key. The existing `prepare_data.py` already converts Hermes-format tool calls to SmolLM3's native tokens. The chat template renders them as the actual special token IDs. The loss mask trains only on assistant turns. So the model will learn: "when I see tools available and the user wants me to use one, I emit token 128015, then the JSON function call, then token 128016."
The two critical changes from the current setup:

1. **Add `embed_tokens` to LoRA target modules** so the adapter can shape the embeddings for the tool-call tokens
2. **Add 30% non-tool instruction data** to prevent overfitting to tool calls

After training, validate with the raw token debugger before deploying to vLLM.