Fix zero-dim tensor concatenation in compressor scale buffer
input_scale and weight_scale_2 are 0-dim scalars in the NVFP4 checkpoint. torch.cat can't concatenate scalars — reshape to 1-d first.
This commit is contained in:
@@ -1611,8 +1611,15 @@ class DeepseekV4Model(nn.Module):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
if len(shards) == 2:
|
||||
# Concatenate shard 0 and shard 1 along dim 0
|
||||
stacked = torch.cat([shards[0], shards[1]], dim=0)
|
||||
# Concatenate shard 0 and shard 1 along dim 0.
|
||||
# Scales may be 0-dim scalars (input_scale, weight_scale_2)
|
||||
# or N-dim tensors (weight_scale); reshape scalars to 1-d.
|
||||
s0, s1 = shards[0], shards[1]
|
||||
if s0.ndim == 0:
|
||||
s0 = s0.reshape(1)
|
||||
if s1.ndim == 0:
|
||||
s1 = s1.reshape(1)
|
||||
stacked = torch.cat([s0, s1], dim=0)
|
||||
else:
|
||||
stacked = shards[0]
|
||||
assert param.data.shape == stacked.shape, (
|
||||
|
||||
Reference in New Issue
Block a user