diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 29f1c4f6..de36f4ff 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -1611,8 +1611,15 @@ class DeepseekV4Model(nn.Module): continue param = params_dict[name] if len(shards) == 2: - # Concatenate shard 0 and shard 1 along dim 0 - stacked = torch.cat([shards[0], shards[1]], dim=0) + # Concatenate shard 0 and shard 1 along dim 0. + # Scales may be 0-dim scalars (input_scale, weight_scale_2) + # or N-dim tensors (weight_scale); reshape scalars to 1-d. + s0, s1 = shards[0], shards[1] + if s0.ndim == 0: + s0 = s0.reshape(1) + if s1.ndim == 0: + s1 = s1.reshape(1) + stacked = torch.cat([s0, s1], dim=0) else: stacked = shards[0] assert param.data.shape == stacked.shape, (