Add embed_tokens to LoRA targets + token ID verification before training

Author: Jinx
Date:   2026-04-10 17:20:06 +00:00
Parent: f46995690c
Commit: d3b5f04f88


@@ -154,6 +154,7 @@ def main():
         target_modules=[
             "q_proj", "k_proj", "v_proj", "o_proj",
             "gate_proj", "up_proj", "down_proj",
+            "embed_tokens",  # Critical: lets LoRA adjust tool-call token embeddings
         ],
         bias="none",
     )
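
Note: targeting "embed_tokens" makes PEFT wrap the embedding layer with a low-rank adapter, so the tool-call token embeddings can actually move during training instead of staying frozen. For context, a minimal sketch of the surrounding LoraConfig, assuming PEFT's standard API; only target_modules is taken from this diff, and the r / lora_alpha / lora_dropout / task_type values are illustrative assumptions:

# Sketch only: hyperparameters below are assumed, not taken from this repo.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,                  # assumed LoRA rank
    lora_alpha=32,         # assumed scaling factor
    lora_dropout=0.05,     # assumed dropout
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens",    # low-rank delta on the embedding matrix
    ],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)  # `model` assumed loaded elsewhere

If the tool-call tokens were newly added to the vocabulary, an alternative is modules_to_save=["embed_tokens", "lm_head"], which trains and saves those matrices in full rather than applying a low-rank delta.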
@@ -166,6 +167,21 @@ def main():
     val_data = load_jsonl(data_dir / "val.jsonl")
     print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")
 
+    # ── Verify tool-call tokens are in the training data ─────────────────
+    print("Verifying tool-call token IDs in training data ...")
+    verification_sample = train_data[0]
+    v_text = tokenizer.apply_chat_template(verification_sample["messages"], tokenize=False)
+    v_ids = tokenizer.encode(v_text)
+    tc_start_found = 128015 in v_ids
+    tc_end_found = 128016 in v_ids
+    if tc_start_found and tc_end_found:
+        print(f"  ✓ Tool-call tokens verified in sample data (128015={tc_start_found}, 128016={tc_end_found})")
+    else:
+        print(f"  ✗ WARNING: Tool-call tokens missing! (128015={tc_start_found}, 128016={tc_end_found})")
+        print(f"  ✗ Training may NOT teach the model to emit tool-call tokens.")
+        print(f"  ✗ Check prepare_data.py and the tokenizer chat template.")
+    # Don't abort — let the user decide, but warn loudly
+
     print("Tokenizing training data ...")
     train_dataset = Dataset.from_list(train_data).map(
         lambda x: tokenize_for_training(x, tokenizer, args.max_length),
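
Note: the check above hardcodes IDs 128015 and 128016 and inspects only train_data[0]. A stricter variant, sketched below assuming a Hugging Face tokenizer, resolves the IDs from the token strings and scans every sample; the token names are hypothetical placeholders, so substitute whatever prepare_data.py actually emits:

# Sketch of a stricter check: resolve IDs by token string and scan all samples.
# The token names here are assumptions, not taken from this repo.
TOOL_TOKENS = ["<|tool_call_start|>", "<|tool_call_end|>"]  # hypothetical names

def verify_tool_tokens(tokenizer, samples):
    # Map token strings to IDs; fail fast if they are not in the vocabulary.
    ids = tokenizer.convert_tokens_to_ids(TOOL_TOKENS)
    if any(i is None or i == tokenizer.unk_token_id for i in ids):
        raise ValueError(f"Tool-call tokens not in vocab: {TOOL_TOKENS} -> {ids}")
    # Collect the indices of samples whose encoding lacks either token.
    missing = []
    for n, sample in enumerate(samples):
        text = tokenizer.apply_chat_template(sample["messages"], tokenize=False)
        encoded = tokenizer.encode(text)
        if not all(i in encoded for i in ids):
            missing.append(n)
    return ids, missing

Resolving IDs by name keeps the check valid if the tokenizer mapping ever changes, and scanning the whole dataset catches per-sample formatting bugs that a single example would miss.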