[Bugfix][LoRA] Fix Qwen35 LoRA (#36976)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -8,7 +8,7 @@ steps:
|
||||
- vllm/lora
|
||||
- tests/lora
|
||||
commands:
|
||||
- pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py
|
||||
- pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py --ignore=lora/test_qwen35_densemoel_lora.py
|
||||
parallelism: 4
|
||||
|
||||
|
||||
@@ -30,4 +30,5 @@ steps:
|
||||
- pytest -v -s -x lora/test_llama_tp.py
|
||||
- pytest -v -s -x lora/test_llm_with_multi_loras.py
|
||||
- pytest -v -s -x lora/test_olmoe_tp.py
|
||||
- pytest -v -s -x lora/test_gptoss_tp.py
|
||||
- pytest -v -s -x lora/test_gptoss_tp.py
|
||||
- pytest -v -s -x lora/test_qwen35_densemoel_lora.py
|
||||
@@ -294,6 +294,11 @@ def whisper_lora_files():
|
||||
return snapshot_download(repo_id="chengyili2005/whisper-small-mandarin-lora")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def qwen35_dense_model_lora_files():
|
||||
return snapshot_download(repo_id="jeeejeee/qwen35-4b-text-only-sql-lora")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_default_device():
|
||||
"""
|
||||
|
||||
132
tests/lora/test_qwen35_densemoel_lora.py
Normal file
132
tests/lora/test_qwen35_densemoel_lora.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import vllm
|
||||
import vllm.config
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from ..utils import create_new_process_for_each_test, multi_gpu_test
|
||||
|
||||
# Base model under test: Qwen3.5 dense 4B (text-only).
MODEL_PATH = "Qwen/Qwen3.5-4B"

# Text-to-SQL prompt over a fixed concert/singer/stadium schema; the
# question is substituted into {query}.
PROMPT_TEMPLATE = """Write a SQL query for the given database.\nSchema:\nTables:\n - stadium(Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average)\n - singer(Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male)\n - concert(concert_ID, concert_Name, Theme, Stadium_ID, Year)\n - singer_in_concert(concert_ID, Singer_ID)\n\nQuestion:\n{query}"""  # noqa: E501

# Exact greedy-decoding outputs expected from the SQL LoRA adapter, one
# per prompt produced in do_sample (same order).
EXPECTED_LORA_OUTPUT = [
    "SELECT count(*) FROM singer",
    "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",
    "SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)",
]


# Module-level tokenizer shared by all tests; only used to render the
# chat template (tokenize=False), never to produce token ids.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
|
||||
|
||||
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Generate SQL completions for three fixed questions.

    Renders each question through the chat template (thinking disabled),
    runs greedy decoding with the given LoRA adapter (or no adapter when
    ``lora_id`` is falsy), and returns the stripped completion texts in
    prompt order.
    """
    questions = [
        "How many singers do we have?",
        (
            "What is the average, minimum, and maximum "
            "age of all singers from France?"
        ),
        "What are the names of the stadiums without any concerts?",
    ]
    input_templates = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": PROMPT_TEMPLATE.format(query=question)}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # disable thinking
        )
        for question in questions
    ]

    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=512)
    lora_request = (
        LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None
    )
    outputs = llm.generate(
        input_templates,
        sampling_params,
        lora_request=lora_request,
    )

    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
def test_qwen35_dense_model_lora(qwen35_dense_model_lora_files):
    """Single-GPU check: both LoRA slots (id 1 and 2) of the SQL adapter
    must reproduce the expected completions exactly under greedy decoding."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_num_seqs=16,
        max_lora_rank=8,
        trust_remote_code=True,
    )

    # Exercise both LoRA slots with the same adapter weights.
    for lora_id in (1, 2):
        generated = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert generated[idx] == expected
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4(qwen35_dense_model_lora_files):
    """TP=4 check without fully-sharded LoRA: both LoRA slots must
    reproduce the expected SQL completions under greedy decoding."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        max_num_seqs=16,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=False,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )

    first = do_sample(llm, qwen35_dense_model_lora_files, lora_id=1)
    print(first)
    for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
        assert first[idx] == expected

    second = do_sample(llm, qwen35_dense_model_lora_files, lora_id=2)
    for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
        assert second[idx] == expected
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4_fully_sharded_loras(qwen35_dense_model_lora_files):
    """TP=4 check WITH fully-sharded LoRA: both LoRA slots must
    reproduce the expected SQL completions under greedy decoding."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=True,
        gpu_memory_utilization=0.8,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )

    # Exercise both LoRA slots with the same adapter weights.
    for lora_id in (1, 2):
        generated = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert generated[idx] == expected
|
||||
@@ -32,9 +32,7 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_pp_group,
|
||||
)
|
||||
@@ -42,7 +40,10 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
GemmaRMSNorm as Qwen3_5RMSNorm,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateCopyFunc,
|
||||
@@ -130,6 +131,40 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
||||
"Qwen3.5 Series dont need to fix query key value ordering"
|
||||
)
|
||||
|
||||
    def __init__(
        self,
        config: Qwen3_5Config,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Initialize the Qwen3.5 gated delta net layer.

        When LoRA is disabled, the parent creates the merged in_proj_qkvz
        projection; when LoRA is enabled, this class instead creates
        separate in_proj_qkv and in_proj_z projections so LoRA adapters
        can target them individually.
        """
        # Skip the parent's merged qkvz projection when LoRA is on; the
        # separate projections below replace it.
        create_in_proj_qkvz = vllm_config.lora_config is None
        super().__init__(
            config,
            vllm_config=vllm_config,
            prefix=prefix,
            create_in_proj_qkvz=create_in_proj_qkvz,
        )
        if vllm_config.lora_config is not None:
            # Separate in_proj_qkv (Q,K,V) and in_proj_z for LoRA compatibility.
            # Use MergedColumnParallelLinear for in_proj_qkv because GDN can have
            # linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32), so
            # output sizes [key_dim, key_dim, value_dim] are not representable
            # with a single QKVParallelLinear (which ties K and V head counts).
            self.in_proj_qkv = MergedColumnParallelLinear(
                input_size=self.hidden_size,
                output_sizes=[self.key_dim, self.key_dim, self.value_dim],
                bias=False,
                quant_config=vllm_config.quant_config,
                prefix=f"{prefix}.in_proj_qkv",
            )
            # Z gate projection; loaded from the z-slice of the checkpoint's
            # merged in_proj_qkvz weight (see the stacked_params_mapping).
            self.in_proj_z = ColumnParallelLinear(
                input_size=self.hidden_size,
                output_size=self.value_dim,
                bias=False,
                quant_config=vllm_config.quant_config,
                prefix=f"{prefix}.in_proj_z",
            )
|
||||
|
||||
def create_qkvz_proj(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -180,15 +215,21 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
||||
# ============================================================
|
||||
# Part 1: Input Projection
|
||||
# ============================================================
|
||||
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
|
||||
hidden_states,
|
||||
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
|
||||
sum(self.in_proj_ba.output_sizes) // self.tp_size,
|
||||
self.prefix,
|
||||
)
|
||||
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
|
||||
z_size = self.value_dim // self.tp_size
|
||||
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
|
||||
if hasattr(self, "in_proj_qkv"):
|
||||
# LoRA path: separate in_proj_qkv and in_proj_z
|
||||
mixed_qkv, _ = self.in_proj_qkv(hidden_states)
|
||||
ba, _ = self.in_proj_ba(hidden_states)
|
||||
z, _ = self.in_proj_z(hidden_states)
|
||||
else:
|
||||
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
|
||||
hidden_states,
|
||||
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
|
||||
sum(self.in_proj_ba.output_sizes) // self.tp_size,
|
||||
self.prefix,
|
||||
)
|
||||
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
|
||||
z_size = self.value_dim // self.tp_size
|
||||
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
|
||||
z = z.reshape(z.size(0), -1, self.head_v_dim)
|
||||
b, a = ba.chunk(2, dim=-1)
|
||||
|
||||
@@ -240,18 +281,14 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
|
||||
self.layer_type = layer_type
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = Qwen3_5GatedDeltaNet(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
speculative_config=speculative_config,
|
||||
config=config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.linear_attn",
|
||||
)
|
||||
elif self.layer_type == "full_attention":
|
||||
@@ -331,6 +368,7 @@ class Qwen3_5Model(Qwen3NextModel):
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
|
||||
self.config = config
|
||||
self.enable_lora = vllm_config.lora_config is not None
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
@@ -396,13 +434,25 @@ class Qwen3_5Model(Qwen3NextModel):
|
||||
# mlp
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
# GDN
|
||||
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
|
||||
("in_proj_qkvz", "in_proj_z", 3),
|
||||
("in_proj_ba", "in_proj_b", 0),
|
||||
("in_proj_ba", "in_proj_a", 1),
|
||||
]
|
||||
|
||||
if self.enable_lora:
|
||||
stacked_params_mapping.extend(
|
||||
[
|
||||
("in_proj_qkv", "in_proj_qkv", (0, 1, 2)),
|
||||
("in_proj_z", "in_proj_z", 0),
|
||||
]
|
||||
)
|
||||
else:
|
||||
stacked_params_mapping.extend(
|
||||
[
|
||||
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
|
||||
("in_proj_qkvz", "in_proj_z", 3),
|
||||
]
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
@@ -450,7 +500,10 @@ class Qwen3_5Model(Qwen3NextModel):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
if param_name == "in_proj_z" and self.enable_lora:
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
@@ -580,6 +633,15 @@ class Qwen3_5ForCausalLMBase(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
# When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
|
||||
# instead of merged in_proj_qkvz; pack mapping must match.
|
||||
if vllm_config.lora_config:
|
||||
base = getattr(Qwen3_5ForCausalLMBase, "packed_modules_mapping", {})
|
||||
self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
|
||||
self.packed_modules_mapping.pop("in_proj_qkvz", None)
|
||||
self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]
|
||||
self.packed_modules_mapping["in_proj_z"] = ["in_proj_z"]
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
@@ -672,6 +734,7 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
|
||||
# protocols have not __init__ method, so we need to use nn.Module.__init__
|
||||
nn.Module.__init__(self)
|
||||
self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)
|
||||
config: Qwen3_5Config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
@@ -699,6 +762,16 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def update_packed_mapping(self, enable_lora: bool):
|
||||
# When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
|
||||
if enable_lora:
|
||||
base = getattr(
|
||||
Qwen3_5ForConditionalGeneration, "packed_modules_mapping", {}
|
||||
)
|
||||
self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
|
||||
self.packed_modules_mapping.pop("in_proj_qkvz", None)
|
||||
self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -879,9 +952,13 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
|
||||
class Qwen3_5MoeForConditionalGeneration(
|
||||
Qwen3_5ForConditionalGeneration, Qwen3_5_MoeMixtureOfExperts
|
||||
):
|
||||
# For MoE LoRA weights loading
|
||||
is_3d_moe_weight: bool = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
|
||||
# protocols have not __init__ method, so we need to use nn.Module.__init__
|
||||
nn.Module.__init__(self)
|
||||
self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)
|
||||
config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
@@ -15,7 +15,6 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
ModelConfig,
|
||||
SpeculativeConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
)
|
||||
@@ -401,11 +400,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3NextConfig,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
speculative_config: SpeculativeConfig | None = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
create_in_proj_qkvz: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -432,10 +429,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.quant_config = quant_config
|
||||
self.speculative_config = speculative_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.num_spec = (
|
||||
self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config
|
||||
@@ -455,13 +452,16 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
# projection of the input hidden states
|
||||
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
|
||||
# we need to create qkvz_proj adaptively here.
|
||||
self.in_proj_qkvz = self.create_qkvz_proj(
|
||||
hidden_size=self.hidden_size,
|
||||
key_dim=self.key_dim,
|
||||
value_dim=self.value_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj_qkvz",
|
||||
)
|
||||
# When create_in_proj_qkvz is False (e.g. LoRA enabled in Qwen3.5),
|
||||
# the subclass creates in_proj_qkv and in_proj_z separately.
|
||||
if create_in_proj_qkvz:
|
||||
self.in_proj_qkvz = self.create_qkvz_proj(
|
||||
hidden_size=self.hidden_size,
|
||||
key_dim=self.key_dim,
|
||||
value_dim=self.value_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj_qkvz",
|
||||
)
|
||||
# ba_proj doesn't support blockwise fp8 quantization.
|
||||
# Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
|
||||
# layouts, so we use a factory method to create the projection.
|
||||
@@ -1207,7 +1207,6 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
|
||||
self.layer_type = layer_type
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
@@ -1215,10 +1214,7 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = Qwen3NextGatedDeltaNet(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
speculative_config=speculative_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.linear_attn",
|
||||
)
|
||||
elif self.layer_type == "full_attention":
|
||||
|
||||
Reference in New Issue
Block a user