[Bugfix][LoRA] Fix Qwen35 LoRA (#36976)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2026-03-20 11:09:32 +08:00
committed by GitHub
parent ea2c148fa7
commit 8fbe3f303f
5 changed files with 257 additions and 46 deletions

View File

@@ -8,7 +8,7 @@ steps:
- vllm/lora
- tests/lora
commands:
- pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py
- pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py --ignore=lora/test_qwen35_densemoel_lora.py
parallelism: 4
@@ -30,4 +30,5 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py
- pytest -v -s -x lora/test_qwen35_densemoel_lora.py

View File

@@ -294,6 +294,11 @@ def whisper_lora_files():
return snapshot_download(repo_id="chengyili2005/whisper-small-mandarin-lora")
@pytest.fixture(scope="session")
def qwen35_dense_model_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen35-4b-text-only-sql-lora")
@pytest.fixture
def reset_default_device():
"""

View File

@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer
import vllm
import vllm.config
from vllm.lora.request import LoRARequest
from ..utils import create_new_process_for_each_test, multi_gpu_test
# Base Qwen3.5 dense model the LoRA adapter was trained against.
MODEL_PATH = "Qwen/Qwen3.5-4B"
# Text-to-SQL prompt wrapper; {query} is filled with the natural-language question.
PROMPT_TEMPLATE = """Write a SQL query for the given database.\nSchema:\nTables:\n - stadium(Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average)\n - singer(Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male)\n - concert(concert_ID, concert_Name, Theme, Stadium_ID, Year)\n - singer_in_concert(concert_ID, Singer_ID)\n\nQuestion:\n{query}""" # noqa: E501
# Golden greedy-decoded (temperature=0) outputs for the three prompts in
# do_sample, in the same order; tests compare element-wise against these.
EXPECTED_LORA_OUTPUT = [
    "SELECT count(*) FROM singer",
    "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",
    "SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)",
]
# Loaded once at import time and shared by all tests in this module.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Greedy-generate SQL for the three fixed questions.

    Each question is wrapped in PROMPT_TEMPLATE and the tokenizer's chat
    template (thinking disabled), then generated with the LoRA adapter at
    ``lora_path`` when ``lora_id`` is truthy, or the base model otherwise.
    Returns the stripped generated texts in prompt order.
    """
    raw_prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=(
                "What is the average, minimum, and maximum "
                "age of all singers from France?"
            )
        ),
        PROMPT_TEMPLATE.format(
            query=("What are the names of the stadiums without any concerts?")
        ),
    ]
    # Apply the chat template to every raw prompt up front.
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": raw}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # disable thinking
        )
        for raw in raw_prompts
    ]
    lora_request = None
    if lora_id:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=512)
    outputs = llm.generate(
        chat_prompts,
        sampling_params,
        lora_request=lora_request,
    )
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
@create_new_process_for_each_test()
def test_qwen35_dense_model_lora(qwen35_dense_model_lora_files):
    """Single-GPU smoke test: the Qwen3.5 dense SQL LoRA must reproduce the
    golden outputs exactly under greedy decoding, for two distinct LoRA ids
    (exercising LoRA slot reuse with max_loras=2)."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_num_seqs=16,
        max_lora_rank=8,
        trust_remote_code=True,
    )
    # Same adapter under two different ids; both passes must match the goldens.
    for lora_id in (1, 2):
        output = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for i, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert output[i] == expected
@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4(qwen35_dense_model_lora_files):
    """TP=4 LoRA test with fully_sharded_loras disabled: both LoRA ids must
    reproduce the golden greedy outputs exactly."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        max_num_seqs=16,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=False,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )
    # do_sample already prints every prompt/output pair, so no extra debug
    # print is needed here. Run both LoRA ids against the goldens.
    for lora_id in (1, 2):
        output = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for i, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert output[i] == expected
@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4_fully_sharded_loras(qwen35_dense_model_lora_files):
    """TP=4 LoRA test with fully_sharded_loras enabled: both LoRA ids must
    reproduce the golden greedy outputs exactly."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=True,
        gpu_memory_utilization=0.8,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )
    # Run both LoRA ids and compare element-wise against the goldens.
    for lora_id in (1, 2):
        output = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for i, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert output[i] == expected

View File

@@ -32,9 +32,7 @@ from einops import rearrange
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
VllmConfig,
)
from vllm.config import VllmConfig
from vllm.distributed import (
get_pp_group,
)
@@ -42,7 +40,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3_5RMSNorm,
)
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
@@ -130,6 +131,40 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
"Qwen3.5 Series dont need to fix query key value ordering"
)
def __init__(
self,
config: Qwen3_5Config,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
create_in_proj_qkvz = vllm_config.lora_config is None
super().__init__(
config,
vllm_config=vllm_config,
prefix=prefix,
create_in_proj_qkvz=create_in_proj_qkvz,
)
if vllm_config.lora_config is not None:
# Separate in_proj_qkv (Q,K,V) and in_proj_z for LoRA compatibility.
# Use MergedColumnParallelLinear for in_proj_qkv because GDN can have
# linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32), so
# output sizes [key_dim, key_dim, value_dim] are not representable
# with a single QKVParallelLinear (which ties K and V head counts).
self.in_proj_qkv = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
bias=False,
quant_config=vllm_config.quant_config,
prefix=f"{prefix}.in_proj_qkv",
)
self.in_proj_z = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.value_dim,
bias=False,
quant_config=vllm_config.quant_config,
prefix=f"{prefix}.in_proj_z",
)
def create_qkvz_proj(
self,
hidden_size: int,
@@ -180,15 +215,21 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# Part 1: Input Projection
# ============================================================
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
sum(self.in_proj_ba.output_sizes) // self.tp_size,
self.prefix,
)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
if hasattr(self, "in_proj_qkv"):
# LoRA path: separate in_proj_qkv and in_proj_z
mixed_qkv, _ = self.in_proj_qkv(hidden_states)
ba, _ = self.in_proj_ba(hidden_states)
z, _ = self.in_proj_z(hidden_states)
else:
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
sum(self.in_proj_ba.output_sizes) // self.tp_size,
self.prefix,
)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim)
b, a = ba.chunk(2, dim=-1)
@@ -240,18 +281,14 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config
self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)
if self.layer_type == "linear_attention":
self.linear_attn = Qwen3_5GatedDeltaNet(
config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
config=config,
vllm_config=vllm_config,
prefix=f"{prefix}.linear_attn",
)
elif self.layer_type == "full_attention":
@@ -331,6 +368,7 @@ class Qwen3_5Model(Qwen3NextModel):
self.num_redundant_experts = eplb_config.num_redundant_experts
self.config = config
self.enable_lora = vllm_config.lora_config is not None
self.vocab_size = config.vocab_size
@@ -396,13 +434,25 @@ class Qwen3_5Model(Qwen3NextModel):
# mlp
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
# GDN
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
("in_proj_ba", "in_proj_b", 0),
("in_proj_ba", "in_proj_a", 1),
]
if self.enable_lora:
stacked_params_mapping.extend(
[
("in_proj_qkv", "in_proj_qkv", (0, 1, 2)),
("in_proj_z", "in_proj_z", 0),
]
)
else:
stacked_params_mapping.extend(
[
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
]
)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
@@ -450,7 +500,10 @@ class Qwen3_5Model(Qwen3NextModel):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
if param_name == "in_proj_z" and self.enable_lora:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
@@ -580,6 +633,15 @@ class Qwen3_5ForCausalLMBase(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
# When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
# instead of merged in_proj_qkvz; pack mapping must match.
if vllm_config.lora_config:
base = getattr(Qwen3_5ForCausalLMBase, "packed_modules_mapping", {})
self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
self.packed_modules_mapping.pop("in_proj_qkvz", None)
self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]
self.packed_modules_mapping["in_proj_z"] = ["in_proj_z"]
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
@@ -672,6 +734,7 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# Protocols have no __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self)
self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)
config: Qwen3_5Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
@@ -699,6 +762,16 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
self.language_model.make_empty_intermediate_tensors
)
def update_packed_mapping(self, enable_lora: bool):
# When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
if enable_lora:
base = getattr(
Qwen3_5ForConditionalGeneration, "packed_modules_mapping", {}
)
self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
self.packed_modules_mapping.pop("in_proj_qkvz", None)
self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]
def embed_input_ids(
self,
input_ids: torch.Tensor,
@@ -879,9 +952,13 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
class Qwen3_5MoeForConditionalGeneration(
Qwen3_5ForConditionalGeneration, Qwen3_5_MoeMixtureOfExperts
):
# For MoE LoRA weights loading
is_3d_moe_weight: bool = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# Protocols have no __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self)
self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)
config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config

View File

@@ -15,7 +15,6 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
ModelConfig,
SpeculativeConfig,
VllmConfig,
get_current_vllm_config,
)
@@ -401,11 +400,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def __init__(
self,
config: Qwen3NextConfig,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
speculative_config: SpeculativeConfig | None = None,
vllm_config: VllmConfig,
prefix: str = "",
create_in_proj_qkvz: bool = True,
) -> None:
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -432,10 +429,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
)
self.config = config
self.model_config = model_config
self.cache_config = cache_config
self.quant_config = quant_config
self.speculative_config = speculative_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.speculative_config = vllm_config.speculative_config
self.num_spec = (
self.speculative_config.num_speculative_tokens
if self.speculative_config
@@ -455,13 +452,16 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# projection of the input hidden states
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
# we need to create qkvz_proj adaptively here.
self.in_proj_qkvz = self.create_qkvz_proj(
hidden_size=self.hidden_size,
key_dim=self.key_dim,
value_dim=self.value_dim,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)
# When create_in_proj_qkvz is False (e.g. LoRA enabled in Qwen3.5),
# the subclass creates in_proj_qkv and in_proj_z separately.
if create_in_proj_qkvz:
self.in_proj_qkvz = self.create_qkvz_proj(
hidden_size=self.hidden_size,
key_dim=self.key_dim,
value_dim=self.value_dim,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)
# ba_proj doesn't support blockwise fp8 quantization.
# Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
# layouts, so we use a factory method to create the projection.
@@ -1207,7 +1207,6 @@ class Qwen3NextDecoderLayer(nn.Module):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config
self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)
@@ -1215,10 +1214,7 @@ class Qwen3NextDecoderLayer(nn.Module):
if self.layer_type == "linear_attention":
self.linear_attn = Qwen3NextGatedDeltaNet(
config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
vllm_config=vllm_config,
prefix=f"{prefix}.linear_attn",
)
elif self.layer_type == "full_attention":