# LoRA Training Plan — Teaching SmolLM3-3B to Emit Tool-Call Tokens

## The Problem

SmolLM3-3B does not emit native tool-call tokens. When asked to use tools, it writes
Python code that *calls* the tool as a function instead of emitting the structured
token sequences (IDs 128015/128016) that vLLM's Hermes parser expects.

This was verified with a raw token inspector that bypasses all middleware:

| Token ID | Decoded | Purpose |
|----------|---------|---------|
| 128015 | ` startPos` | Tool call start delimiter |
| 128016 | ` endPos` | Tool call end delimiter |
| 128013 | ` eni` | Tool response start delimiter |
| 128014 | ` eni_result` | Tool response end delimiter |

The model knows these tokens exist (they're in its vocabulary and the chat template
references them), but it was never trained to *produce* them. It falls back to the
only behavior it knows: writing code.

## Training Strategy

### Core Principle: Train on Raw Tokens, Not Parsed Output

The training data must contain the **actual token IDs** 128015/128016 in the assistant
response. We cannot use any parser or template that "corrects" the output — that would
mask the problem instead of fixing it.

The `apply_chat_template()` call in `train_lora.py` handles this correctly: when the
messages contain an assistant turn with a tool call, the template renders it using the
special delimiters. The loss mask ensures only the assistant's tokens are trained on.
So the model learns: "when tools are available and the user asks me to use one, I emit
the start token, then JSON, then the end token."
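
As a sanity check on the masking logic, here is a minimal sketch (plain Python with a
toy token sequence; the span boundaries would come from the chat template in practice)
of how assistant-only loss labels are typically built: every position outside the
assistant span gets `-100` so cross-entropy ignores it.

```python
# Sketch of assistant-only label masking (toy IDs, no real tokenizer).
# Positions outside the assistant span are set to -100 so the loss skips them.

IGNORE_INDEX = -100

def build_labels(input_ids, assistant_start, assistant_end):
    """Copy input_ids, but mask everything outside [assistant_start, assistant_end)."""
    return [
        tok if assistant_start <= i < assistant_end else IGNORE_INDEX
        for i, tok in enumerate(input_ids)
    ]

# Toy sequence: [system/user tokens..., 128015, json tokens..., 128016, eos]
input_ids = [1, 2, 3, 4, 128015, 500, 501, 128016, 2]
labels = build_labels(input_ids, assistant_start=4, assistant_end=9)
# The tool-call delimiters fall inside the labeled region, so they are trained on.
```

The point is that 128015/128016 must land inside the unmasked region; if they fall in
the `-100` zone, the model never gets gradient signal for them.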

### What the Data Must Look Like

Every training sample must have the tool-call delimiters wrapping JSON in the
assistant turn. The key is that after tokenization, token IDs 128015 and 128016
appear in the `input_ids` array within the assistant's labeled region.

**What we need** (tokenized `input_ids` contains `[128015, ...json..., 128016]`):
- Assistant turn with: start delimiter + JSON tool call + end delimiter

**What the base model currently produces** (prose/code tokens, no special tokens):
- Assistant turn with: prose explaining the tool + Python code calling it

The training data must have the first pattern. The loss function only trains on
assistant turns, so the model will learn to emit the special tokens.
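
Concretely, a single training sample might look like the following sketch. The
delimiter strings here are placeholders — the real strings are whatever the tokenizer
maps IDs 128015/128016 to (e.g. via `tokenizer.convert_ids_to_tokens`), and the
`write_file` call is just an illustrative payload:

```python
import json

# Placeholders for the strings that tokenize to IDs 128015 / 128016.
TOOL_CALL_START = "<tool_call_start>"
TOOL_CALL_END = "<tool_call_end>"

# Hypothetical tool-call payload for illustration.
call = {"name": "write_file", "arguments": {"path": "notes.txt", "content": "hi"}}

sample = {
    "messages": [
        {"role": "user", "content": "Save 'hi' to notes.txt"},
        {
            "role": "assistant",
            # Delimiters wrap the JSON inside the assistant turn — the pattern
            # the loss mask will train on.
            "content": f"{TOOL_CALL_START}{json.dumps(call)}{TOOL_CALL_END}",
        },
    ]
}
```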

### Dataset Choice

**Primary: NousResearch/Hermes-Function-Calling-V1**

Why this one:
1. **Already has tool_call tags in the text** — the assistant responses contain the
   standard Hermes tool-call format
2. **Large and diverse** — 20k+ tool-calling conversations covering many function types
3. **Multi-turn** — includes tool responses and follow-up turns, so the model also
   learns to read the response delimiters and respond to tool results
4. **Clean format** — ShareGPT, easy to convert

The conversion in `prepare_data.py` already does the right thing: it replaces the
Hermes tags with SmolLM3's native token strings. After conversion, the training data
contains those delimiter strings in the text. When `apply_chat_template()` tokenizes
this, tokens 128015 and 128016 end up in the `input_ids`, and since they're in the
assistant turn, the loss mask includes them.
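
The rewrite in `prepare_data.py` is not reproduced here, but the idea can be sketched
as a plain string substitution. This assumes the Hermes data wraps calls in
`<tool_call>...</tool_call>` tags; the real script and the real replacement strings
(which should come from `tokenizer.convert_ids_to_tokens([128015, 128016])`) may differ:

```python
# Sketch of a Hermes-tag -> native-delimiter rewrite (assumed tag names;
# the actual prepare_data.py may handle more cases).
def convert_hermes(text: str, start_str: str, end_str: str) -> str:
    """Swap Hermes tool-call tags for the model's native delimiter strings."""
    return (
        text.replace("<tool_call>", start_str)
            .replace("</tool_call>", end_str)
    )

hermes = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
native = convert_hermes(hermes, "<TC_START>", "<TC_END>")
# native now uses the placeholder delimiters instead of Hermes tags.
```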

**Drop: interstellarninja/tool-calls-multiturn** for now — it's noisier and has more
formatting inconsistency. We can add it later if we need more volume, but starting
clean is better for this focused training run.

**Skip: Salesforce/xLAM-function-calling-60k** — too large, lots of low-quality
samples. We want quality over quantity for a LoRA.

### Data Mix: Don't Forget Non-Tool Turns

If we train on *only* tool-calling samples, the model may learn to always emit the
start delimiter even when it shouldn't. We need a mix:

- **70% tool-calling samples** (Hermes V1, filtered to only tool-call conversations)
- **30% general instruction-following samples** (from SmolLM3's original training data
  or a clean instruct dataset like `HuggingFaceTB/smollm-corpus`)

This teaches the model: "emit the tool-call tokens when tools are available AND the
user's request needs a tool call; respond with normal text otherwise."
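
One way to realize the 70/30 split reproducibly is a seeded sample-and-shuffle. A
minimal sketch (dataset loading is out of scope; the function just takes two lists of
already-converted samples):

```python
import random

def mix_datasets(tool_samples, general_samples, tool_frac=0.7, total=10_000, seed=0):
    """Build a shuffled mix with tool_frac tool-calling samples, seeded for repeatability."""
    rng = random.Random(seed)
    n_tool = int(total * tool_frac)
    n_general = total - n_tool
    # Sample without replacement from each pool, then interleave by shuffling.
    mixed = rng.sample(tool_samples, n_tool) + rng.sample(general_samples, n_general)
    rng.shuffle(mixed)
    return mixed
```

The fixed seed matters: it keeps the mix identical across reruns, so a training change
can't be confused with a data-shuffle change.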

### Critical Fix: Verify Token IDs in Training Data

Before training, add a verification step that confirms the processed data actually
contains token IDs 128015 and 128016. If `apply_chat_template()` doesn't render them
correctly, the model won't learn them.

```python
# Verification: tokenize a sample and check for tool-call token IDs
sample = train_data[0]
text = tokenizer.apply_chat_template(sample["messages"], tokenize=False)
ids = tokenizer.encode(text)
assert 128015 in ids, "Tool call start token (128015) not found in training data!"
assert 128016 in ids, "Tool call end token (128016) not found in training data!"
```

If this fails, the `prepare_data.py` conversion isn't working or the tokenizer isn't
recognizing the special tokens. Fix before training.

### Training Parameters

The current defaults are reasonable. Key considerations:

| Param | Value | Rationale |
|-------|-------|-----------|
| lora_r | 16 | Standard for a 3B model. Could go to 32 if loss plateaus. |
| lora_alpha | 32 | 2x rank, standard. |
| lr | 2e-4 | Good for LoRA. If loss spikes, drop to 1e-4. |
| epochs | 3 | Start here. Check val loss — if still dropping at epoch 3, go to 5. |
| max_length | 4096 | Enough for tool calls + code content. |
| target_modules | q/k/v/o + gate/up/down + embed_tokens | Full coverage. embed_tokens critical — see below. |
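
Assuming the run uses PEFT, the table maps onto a `LoraConfig` roughly like this (a
sketch, not the actual `train_lora.py` config; `lora_dropout` is an assumed value the
table does not specify, and the learning rate and epochs live in the trainer arguments,
not here):

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                  # lora_r from the table; try 32 if loss plateaus
    lora_alpha=32,         # 2x rank, standard
    lora_dropout=0.05,     # assumed default; not specified in the table
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens",    # full coverage, including embeddings
    ],
)
```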

### The embed_tokens Question

Token IDs 128015/128016 are in the vocabulary, but the model has never used them in
context. Two options:

1. **Include `embed_tokens` in the LoRA target modules** — lets the adapter adjust the
   embeddings for these tokens so the model can route to them during generation.
   **Recommended for this run.** Add `"embed_tokens"` to `target_modules`.

2. **Don't include it** — the base embeddings stay frozen; the model can still produce
   these tokens (they're in its vocab) but may not have a strong enough signal to
   route to them. Risky.

Add to `target_modules` in `train_lora.py`:

```python
target_modules=[
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "embed_tokens",  # Critical: lets LoRA adjust tool-call token embeddings
],
```

Note: PEFT handles `embed_tokens` with LoRA correctly — it applies the low-rank
adaptation to the embedding matrix without issues.

### Validation: Post-Training Token Emission Test

After training, before deploying, run the trained model through the
`chat-template-debugger` (stage 1) to verify that the model now emits 128015/128016:

1. Merge the LoRA adapter into the base model
2. Copy the merged model to the GPU server's `chat-template-debugger/models/`
3. Run `stage1_debug.py` with the write_file and save_config prompts
4. **Pass criteria:** token 128015 appears in the output, followed by valid JSON,
   followed by token 128016. No Python code-dumping.
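
The pass criteria can also be checked mechanically. A small helper (a sketch in plain
Python; `decode_fn` stands in for something like `tokenizer.decode`, which is assumed
to exist in the validation script):

```python
import json

TOOL_CALL_START, TOOL_CALL_END = 128015, 128016

def check_tool_call(output_ids, decode_fn):
    """Assert the output contains a well-formed 128015 ... json ... 128016 span
    and return the parsed JSON payload."""
    assert TOOL_CALL_START in output_ids, "no start delimiter emitted"
    assert TOOL_CALL_END in output_ids, "no end delimiter emitted"
    start = output_ids.index(TOOL_CALL_START)
    end = output_ids.index(TOOL_CALL_END)
    assert start < end, "delimiters out of order"
    payload = decode_fn(output_ids[start + 1:end])
    return json.loads(payload)  # raises if the span is not valid JSON
```

A generation that code-dumps instead of emitting the delimiters fails at the first
assertion, which makes the regression easy to catch in CI.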

### Failure Modes & Mitigations

| Failure | Symptom | Fix |
|---------|---------|-----|
| Model still code-dumps | No 128015/128016 in output | Check verification step; increase epochs; add embed_tokens to targets |
| Model emits tokens but broken JSON | 128015 present but invalid JSON follows | Add more diverse tool-call samples; increase max_length |
| Model over-fits to tool calls | Emits start delimiter for non-tool queries | Add 30% non-tool instruction data |
| Loss doesn't decrease | Val loss flat or increasing | Check label masking (not all -100); verify data quality; lower LR |
| LoRA can't adjust embeddings | embed_tokens not in target_modules | PEFT supports this — make sure the module name matches exactly |

## Summary

The training data is the key. The existing `prepare_data.py` already converts
Hermes-format tool calls to SmolLM3's native tokens. The chat template renders them
as the actual special token IDs. The loss mask trains only on assistant turns. So the
model will learn: "when I see tools available and the user wants me to use one, I emit
token 128015, then the JSON function call, then token 128016."

The two critical changes from the current setup:
1. **Add `embed_tokens` to the LoRA target modules** so the adapter can shape the
   embeddings for the tool-call tokens
2. **Add 30% non-tool instruction data** to prevent overfitting to tool calls

After training, validate with the raw token debugger before deploying to vLLM.