Fix NVFP4 compressor scale loading: buffer and concatenate scale shards

The stacked params mapping (wkv + wgate → fused_wkv_wgate) uses
weight_loader(param, weight, shard_id), but PerTensorScaleParameter
and ModelWeightParameter for NVFP4 scale params don't support shard_id
in load_column_parallel_weight (asserts shape equality).

Fix: buffer input_scale, weight_scale, weight_scale_2 for fused_wkv_wgate
shards, then concatenate along dim 0 and copy_ into the param after all
weights are loaded.
This commit is contained in:
2026-05-18 23:24:08 +00:00
parent f74447bfd0
commit eef0ef76af

View File

@@ -1469,6 +1469,14 @@ class DeepseekV4Model(nn.Module):
# Pre-compute expert mapping ONCE.
expert_mapping = self.get_expert_mapping()
# NVFP4 compressor/indexer scale params need special handling:
# wkv.input_scale (shape [1]) + wgate.input_scale (shape [1])
# must be concatenated into fused_wkv_wgate.input_scale (shape [2]).
# The default stacking path fails because PerTensorScaleParameter's
# weight_loader asserts shape equality.
# We buffer them and load once both shards are available.
compressor_scale_buffer: dict[str, dict[int, torch.Tensor]] = {}
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -1482,6 +1490,29 @@ class DeepseekV4Model(nn.Module):
break
param = params_dict[name]
weight_loader = param.weight_loader
# NVFP4 scale params for stacked fused_wkv_wgate need
# special handling: each shard (wkv, wgate) has scale
# shape [1] or [head_dim, K], but the fused param has
# shape [2] or [2*head_dim, K]. The default stacking
# weight_loader can't handle this for PerTensorScale or
# ModelWeight scale params. Buffer and concatenate.
is_compressor_scale = (
"fused_wkv_wgate" in name
and name.endswith((
"input_scale",
"weight_scale",
"weight_scale_2",
))
)
if is_compressor_scale:
# Buffer the shard for later concatenation
if name not in compressor_scale_buffer:
compressor_scale_buffer[name] = {}
compressor_scale_buffer[name][shard_id] = loaded_weight
loaded_params.add(name)
break
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
@@ -1542,6 +1573,24 @@ class DeepseekV4Model(nn.Module):
loaded_params.add(name)
continue
# Load buffered compressor/indexer scale params.
# These are NVFP4 quantization scales that need concatenation
# across shards (wkv=shard0, wgate=shard1) before loading.
for name, shards in compressor_scale_buffer.items():
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if len(shards) == 2:
# Concatenate shard 0 and shard 1 along dim 0
stacked = torch.cat([shards[0], shards[1]], dim=0)
else:
stacked = shards[0]
assert param.data.shape == stacked.shape, (
f"Scale shape mismatch for {name}: "
f"param={param.data.shape} loaded={stacked.shape}"
)
param.data.copy_(stacked)
return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: