Fix zero-dim tensor concatenation in compressor scale buffer

input_scale and weight_scale_2 are stored as 0-dim (scalar) tensors in the NVFP4 checkpoint.
torch.cat cannot concatenate zero-dimensional tensors, so reshape each one to 1-d (shape [1]) first.
This commit is contained in:
2026-05-19 00:10:13 +00:00
parent d41a48aa1f
commit 201a40e6c4

View File

@@ -1611,8 +1611,15 @@ class DeepseekV4Model(nn.Module):
continue
param = params_dict[name]
if len(shards) == 2:
# Concatenate shard 0 and shard 1 along dim 0
stacked = torch.cat([shards[0], shards[1]], dim=0)
# Concatenate shard 0 and shard 1 along dim 0.
# Scales may be 0-dim scalars (input_scale, weight_scale_2)
# or N-dim tensors (weight_scale); reshape scalars to 1-d.
s0, s1 = shards[0], shards[1]
if s0.ndim == 0:
s0 = s0.reshape(1)
if s1.ndim == 0:
s1 = s1.reshape(1)
stacked = torch.cat([s0, s1], dim=0)
else:
stacked = shards[0]
assert param.data.shape == stacked.shape, (