[Bugfix][Model] Support LoRA on Qwen3 Output Embedding (#29816)
Signed-off-by: kurt <kurt@thinkingmachines.ai>
This commit is contained in:
100
tests/lora/test_qwen3_unembed.py
Normal file
100
tests/lora/test_qwen3_unembed.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for Qwen3 unembed LoRA support.
|
||||
|
||||
This test creates synthetic LoRA weights that include lm_head (output embedding)
|
||||
to verify that Qwen3 properly supports LoRA on the unembed/lm_head layer.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL_PATH = "Qwen/Qwen3-0.6B"
|
||||
HIDDEN_SIZE = 1024
|
||||
VOCAB_SIZE = 151936
|
||||
|
||||
|
||||
def create_qwen3_lora_with_lm_head(save_dir: str, rank: int = 8) -> None:
|
||||
"""Create synthetic Qwen3 LoRA weights with lm_head."""
|
||||
lora_weights = {}
|
||||
for module in ["q_proj", "v_proj"]:
|
||||
lora_A = torch.from_numpy(
|
||||
np.random.randn(rank, HIDDEN_SIZE).astype(np.float16) * 0.01
|
||||
)
|
||||
lora_B = torch.zeros(HIDDEN_SIZE, rank, dtype=torch.float16)
|
||||
key_prefix = f"base_model.model.model.layers.0.self_attn.{module}"
|
||||
lora_weights[f"{key_prefix}.lora_A.weight"] = lora_A
|
||||
lora_weights[f"{key_prefix}.lora_B.weight"] = lora_B
|
||||
|
||||
# lm_head LoRA weights
|
||||
lora_weights["base_model.model.lm_head.lora_A.weight"] = torch.from_numpy(
|
||||
np.random.randn(rank, HIDDEN_SIZE).astype(np.float16) * 0.01
|
||||
)
|
||||
lora_weights["base_model.model.lm_head.lora_B.weight"] = torch.zeros(
|
||||
VOCAB_SIZE, rank, dtype=torch.float16
|
||||
)
|
||||
|
||||
adapter_config = {
|
||||
"peft_type": "LORA",
|
||||
"base_model_name_or_path": MODEL_PATH,
|
||||
"task_type": "CAUSAL_LM",
|
||||
"inference_mode": True,
|
||||
"r": rank,
|
||||
"lora_alpha": rank * 2,
|
||||
"lora_dropout": 0.0,
|
||||
"bias": "none",
|
||||
"target_modules": ["q_proj", "v_proj", "lm_head"],
|
||||
}
|
||||
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
with open(os.path.join(save_dir, "adapter_config.json"), "w") as f:
|
||||
json.dump(adapter_config, f)
|
||||
save_file(lora_weights, os.path.join(save_dir, "adapter_model.safetensors"))
|
||||
|
||||
|
||||
def test_qwen3_unembed_lora():
|
||||
"""Verify Qwen3 can load and generate with LoRA adapters with lm_head."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Initialize engine first (before creating torch tensors)
|
||||
llm = LLM(
|
||||
model=MODEL_PATH,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
max_lora_rank=8,
|
||||
max_model_len=128,
|
||||
gpu_memory_utilization=0.8,
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
# Create LoRA weights after engine init
|
||||
create_qwen3_lora_with_lm_head(tmpdir, rank=8)
|
||||
|
||||
lora_request = LoRARequest("lm_head_lora", 1, tmpdir)
|
||||
llm.llm_engine.add_lora(lora_request)
|
||||
|
||||
assert 1 in llm.llm_engine.list_loras(), "lm_head LoRA should be loaded"
|
||||
|
||||
# Test generation
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=32)
|
||||
prompts = ["Hello, my name is"]
|
||||
|
||||
# Generate with base model (no LoRA)
|
||||
base_outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
assert len(base_outputs) == 1
|
||||
assert len(base_outputs[0].outputs[0].text) > 0
|
||||
|
||||
# Generate with lm_head LoRA
|
||||
lora_outputs = llm.generate(
|
||||
prompts, sampling_params, lora_request=lora_request, use_tqdm=False
|
||||
)
|
||||
assert len(lora_outputs) == 1
|
||||
assert len(lora_outputs[0].outputs[0].text) > 0
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
@@ -148,7 +147,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
embedding_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
# Get the logits for the next tokens.
|
||||
logits = lm_head.quant_method.apply(lm_head, hidden_states)
|
||||
if hasattr(lm_head, "base_layer"):
|
||||
actual_lm_head = lm_head.base_layer
|
||||
else:
|
||||
actual_lm_head = lm_head
|
||||
logits = actual_lm_head.quant_method.apply(actual_lm_head, hidden_states)
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -452,16 +452,23 @@ class LoRAModelManager:
|
||||
if module_name not in self.packed_modules:
|
||||
assert embedding_modules is not None
|
||||
if parts[-1] in embedding_modules:
|
||||
input_dim = (
|
||||
module.base_layer.org_vocab_size
|
||||
if hasattr(module.base_layer, "org_vocab_size")
|
||||
else module.base_layer.weight.shape[1]
|
||||
)
|
||||
output_dim = (
|
||||
module.base_layer.embedding_dim
|
||||
if hasattr(module.base_layer, "embedding_dim")
|
||||
else module.base_layer.weight.shape[0]
|
||||
)
|
||||
# Special-case lm_head: wrapped by LogitsProcessorWithLoRA.
|
||||
# LoRA input dim is hidden_size, output dim is vocab size.
|
||||
# LogitsProcessorWithLoRA handles extra vocab size directly.
|
||||
if parts[-1] == "lm_head":
|
||||
input_dim = module.lora_a_stacked[0].shape[-1]
|
||||
output_dim = module.lora_b_stacked[0].shape[-2]
|
||||
else:
|
||||
input_dim = (
|
||||
module.base_layer.org_vocab_size
|
||||
if hasattr(module.base_layer, "org_vocab_size")
|
||||
else module.base_layer.weight.shape[1]
|
||||
)
|
||||
output_dim = (
|
||||
module.base_layer.embedding_dim
|
||||
if hasattr(module.base_layer, "embedding_dim")
|
||||
else module.base_layer.weight.shape[0]
|
||||
)
|
||||
lora = LoRALayerWeights.create_dummy_lora_weights(
|
||||
module_name,
|
||||
input_dim,
|
||||
|
||||
@@ -271,6 +271,11 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
],
|
||||
}
|
||||
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@@ -689,6 +689,11 @@ class Qwen3MoeForCausalLM(
|
||||
]
|
||||
}
|
||||
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
Reference in New Issue
Block a user