diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index bbd7267fd..23035816b 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -685,8 +685,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: int | None = None, + loaded_shard_id: tuple[int, ...] | int | None = None, ): + if isinstance(loaded_shard_id, tuple): + raise NotImplementedError( + "Shard id with multiple indices is not supported in weight_loader, " + "please use weight_loader_v2 instead." + ) # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -825,7 +830,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param_data.copy_(loaded_weight) def _load_fused_module_from_checkpoint( - self, param: BasevLLMParameter, loaded_weight: torch.Tensor + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + output_sizes: list[int] | None = None, ): """ Handle special case for models where MLP layers are already @@ -839,7 +847,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): current_shard_offset = 0 shard_offsets: list[tuple[int, int, int]] = [] - for i, output_size in enumerate(self.output_sizes): + output_sizes = output_sizes or self.output_sizes + for i, output_size in enumerate(output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) current_shard_offset += output_size @@ -864,17 +873,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear): self, param: BasevLLMParameter, loaded_weight: torch.Tensor, - loaded_shard_id: int | None = None, + loaded_shard_id: tuple[int, ...] | int | None = None, ): - if loaded_shard_id is None: + if loaded_shard_id is None or isinstance(loaded_shard_id, tuple): if isinstance(param, PerTensorScaleParameter): param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight) return + output_sizes = ( + [self.output_sizes[idx] for idx in loaded_shard_id] + if loaded_shard_id + else None + ) + if isinstance(param, BlockQuantScaleParameter): + weight_block_size = getattr(self, "weight_block_size", None) + output_sizes = [ + adjust_block_scale_shard(weight_block_size, size, 0)[0] + for size in (output_sizes or self.output_sizes) + ] # TODO: @dsikka - move to parameter.py - self._load_fused_module_from_checkpoint(param, loaded_weight) + self._load_fused_module_from_checkpoint( + param, loaded_weight, output_sizes=output_sizes + ) return assert loaded_shard_id < len(self.output_sizes) diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index 5c76bf7ef..7c355e8b0 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -30,36 +30,20 @@ from collections.abc import Callable, Iterable import torch from einops import rearrange from torch import nn -from transformers.activations import ACT2FN from vllm.compilation.decorators import support_torch_compile from vllm.config import ( - CacheConfig, - ModelConfig, - SpeculativeConfig, VllmConfig, - get_current_vllm_config, ) from vllm.distributed import ( - divide, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3_5RMSNorm, ) -from vllm.model_executor.layers.layernorm import RMSNormGated -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear, -) +from vllm.model_executor.layers.linear import MergedColumnParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - mamba_v2_sharded_weight_loader, -) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateCopyFunc, MambaStateCopyFuncCalculator, @@ -73,11 +57,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, - sharded_weight_loader, ) -from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.qwen3_5 import ( Qwen3_5Config, @@ -99,7 +80,6 @@ from .interfaces import ( ) from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from .qwen3_next import ( - ChunkGatedDeltaRule, Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextGatedDeltaNet, @@ -139,154 +119,31 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo): class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): - def __init__( - self, - config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig, - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - speculative_config: SpeculativeConfig | None = None, - prefix: str = "", - ) -> None: - super(Qwen3NextGatedDeltaNet, self).__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads - self.head_k_dim = config.linear_key_head_dim - self.head_v_dim = config.linear_value_head_dim - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - - self.conv_kernel_size = config.linear_conv_kernel_dim - self.layer_idx = extract_layer_index(prefix) - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - self.layer_norm_epsilon = config.rms_norm_eps - self.prefix = prefix - - self.config = config - self.model_config = model_config - self.cache_config = cache_config - self.quant_config = quant_config - self.speculative_config = speculative_config - self.num_spec = ( - self.speculative_config.num_speculative_tokens - if self.speculative_config - else 0 - ) - - # QKV - self.conv_dim = self.key_dim * 2 + self.value_dim - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.conv_dim, - bias=False, - prefix=f"{prefix}.conv1d", - ) - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj_qkv = MergedColumnParallelLinear( - input_size=self.hidden_size, - output_sizes=[self.key_dim, self.key_dim, self.value_dim], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_qkv", - ) - self.in_proj_z = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.value_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_z", - ) - self.in_proj_b = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.num_v_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_b", - ) - self.in_proj_a = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.num_v_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_a", - ) - - query_key_settings = (self.key_dim, 0, False) - value_settings = (self.value_dim, 0, False) - - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - query_key_settings, - query_key_settings, - value_settings, - ], - self.tp_size, - self.tp_rank, - ) - }, - ) - - # selective projection used to make dt, B and C input dependant - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - self.dt_bias = nn.Parameter( - torch.ones(self.num_v_heads // self.tp_size), - ) - self.A_log = nn.Parameter( - torch.empty( - divide(self.num_v_heads, self.tp_size), - ) - ) - - set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - group_size=None, - norm_before_gate=True, - device=current_platform.current_device(), - dtype=config.dtype, - ) - - self.out_proj = RowParallelLinear( - self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - - self.chunk_gated_delta_rule = ChunkGatedDeltaRule() - - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - def fix_query_key_value_ordering( self, - mixed_qkv, - z, - b, - a, + mixed_qkvz: torch.Tensor, + mixed_ba: torch.Tensor, ): raise NotImplementedError( "Qwen3.5 Series dont need to fix query key value ordering" ) + def create_qkvz_proj( + self, + hidden_size: int, + key_dim: int, + value_dim: int, + quant_config: QuantizationConfig | None, + prefix: str, + ) -> MergedColumnParallelLinear: + return MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[key_dim, key_dim, value_dim, value_dim], + bias=False, + quant_config=quant_config, + prefix=prefix, + ) + def forward( self, hidden_states: torch.Tensor, @@ -303,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): # ============================================================ # Part 1: Input Projection # ============================================================ - mixed_qkv, _ = self.in_proj_qkv(hidden_states) - z, _ = self.in_proj_z(hidden_states) + mixed_qkvz, _ = self.in_proj_qkvz(hidden_states) + qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size + z_size = self.value_dim // self.tp_size + mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1) z = z.reshape(z.size(0), -1, self.head_v_dim) - b, _ = self.in_proj_b(hidden_states) - a, _ = self.in_proj_a(hidden_states) + ba, _ = self.in_proj_ba(hidden_states) + b, a = ba.chunk(2, dim=-1) b = b.contiguous() a = a.contiguous() @@ -506,11 +365,18 @@ class Qwen3_5Model(Qwen3NextModel): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) + # self attention ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), + # mlp ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + # GDN + ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)), + ("in_proj_qkvz", "in_proj_z", 3), + ("in_proj_ba", "in_proj_b", 0), + ("in_proj_ba", "in_proj_a", 1), ] params_dict = dict(self.named_parameters()) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 6da5bca1b..59468c7bf 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -44,6 +44,7 @@ from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.linear import ( ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, @@ -406,19 +407,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) # projection of the input hidden states - self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 - self.projection_size_ba = self.num_v_heads * 2 - self.in_proj_qkvz = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.projection_size_qkvz, - bias=False, + # Qwen3-Next and Qwen3.5 has a different qkv_proj layout, + # we need to create qkvz_proj adaptively here. + self.in_proj_qkvz = self.create_qkvz_proj( + hidden_size=self.hidden_size, + key_dim=self.key_dim, + value_dim=self.value_dim, quant_config=quant_config, prefix=f"{prefix}.in_proj_qkvz", ) # ba_proj doesn't support blockwise fp8 quantization. - self.in_proj_ba = ColumnParallelLinear( + self.in_proj_ba = MergedColumnParallelLinear( input_size=self.hidden_size, - output_size=self.projection_size_ba, + output_sizes=[self.num_v_heads] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.in_proj_ba", @@ -484,10 +485,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self + def create_qkvz_proj( + self, + hidden_size: int, + key_dim: int, + value_dim: int, + quant_config: QuantizationConfig | None, + prefix: str, + ) -> MergedColumnParallelLinear: + return MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[sum((key_dim, key_dim, value_dim)), value_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_qkvz", + ) + def fix_query_key_value_ordering( self, - mixed_qkvz, - mixed_ba, + mixed_qkvz: torch.Tensor, + mixed_ba: torch.Tensor, ): """ Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.