diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 8b62ff0e..89b7665c 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -1469,6 +1469,14 @@ class DeepseekV4Model(nn.Module): # Pre-compute expert mapping ONCE. expert_mapping = self.get_expert_mapping() + # NVFP4 compressor/indexer scale params need special handling: + # wkv.input_scale (shape [1]) + wgate.input_scale (shape [1]) + # must be concatenated into fused_wkv_wgate.input_scale (shape [2]). + # The default stacking path fails because PerTensorScaleParameter's + # weight_loader asserts shape equality. + # We buffer them and load once both shards are available. + compressor_scale_buffer: dict[str, dict[int, torch.Tensor]] = {} + for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -1482,6 +1490,29 @@ class DeepseekV4Model(nn.Module): break param = params_dict[name] weight_loader = param.weight_loader + + # NVFP4 scale params for stacked fused_wkv_wgate need + # special handling: each shard (wkv, wgate) has scale + # shape [1] or [head_dim, K], but the fused param has + # shape [2] or [2*head_dim, K]. The default stacking + # weight_loader can't handle this for PerTensorScale or + # ModelWeight scale params. Buffer and concatenate. + is_compressor_scale = ( + "fused_wkv_wgate" in name + and name.endswith(( + "input_scale", + "weight_scale", + "weight_scale_2", + )) + ) + if is_compressor_scale: + # Buffer the shard for later concatenation + if name not in compressor_scale_buffer: + compressor_scale_buffer[name] = {} + compressor_scale_buffer[name][shard_id] = loaded_weight + loaded_params.add(name) + break + weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) break @@ -1542,6 +1573,24 @@ class DeepseekV4Model(nn.Module): loaded_params.add(name) continue + # Load buffered compressor/indexer scale params. + # These are NVFP4 quantization scales that need concatenation + # across shards (wkv=shard0, wgate=shard1) before loading. + for name, shards in compressor_scale_buffer.items(): + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + if len(shards) == 2: + # Concatenate shard 0 and shard 1 along dim 0 + stacked = torch.cat([shards[0], shards[1]], dim=0) + else: + stacked = shards[0] + assert param.data.shape == stacked.shape, ( + f"Scale shape mismatch for {name}: " + f"param={param.data.shape} loaded={stacked.shape}" + ) + param.data.copy_(stacked) + return loaded_params def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: