Fix NVFP4 compressor scale loading: buffer and concatenate scale shards
The stacked params mapping (wkv + wgate → fused_wkv_wgate) calls weight_loader(param, weight, shard_id), but for NVFP4 scale params the PerTensorScaleParameter and ModelWeightParameter loaders do not support shard_id: load_column_parallel_weight asserts that the loaded tensor's shape equals the param's shape, which cannot hold for a single shard of a fused param. Fix: buffer the input_scale, weight_scale, and weight_scale_2 shards of fused_wkv_wgate, then, after all weights are loaded, concatenate the shards along dim 0 and copy_ the result into the param.
This commit is contained in:
@@ -1469,6 +1469,14 @@ class DeepseekV4Model(nn.Module):
|
||||
# Pre-compute expert mapping ONCE.
|
||||
expert_mapping = self.get_expert_mapping()
|
||||
|
||||
# NVFP4 compressor/indexer scale params need special handling:
|
||||
# wkv.input_scale (shape [1]) + wgate.input_scale (shape [1])
|
||||
# must be concatenated into fused_wkv_wgate.input_scale (shape [2]).
|
||||
# The default stacking path fails because PerTensorScaleParameter's
|
||||
# weight_loader asserts shape equality.
|
||||
# We buffer them and load once both shards are available.
|
||||
compressor_scale_buffer: dict[str, dict[int, torch.Tensor]] = {}
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
@@ -1482,6 +1490,29 @@ class DeepseekV4Model(nn.Module):
|
||||
break
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
|
||||
# NVFP4 scale params for stacked fused_wkv_wgate need
|
||||
# special handling: each shard (wkv, wgate) has scale
|
||||
# shape [1] or [head_dim, K], but the fused param has
|
||||
# shape [2] or [2*head_dim, K]. The default stacking
|
||||
# weight_loader can't handle this for PerTensorScale or
|
||||
# ModelWeight scale params. Buffer and concatenate.
|
||||
is_compressor_scale = (
|
||||
"fused_wkv_wgate" in name
|
||||
and name.endswith((
|
||||
"input_scale",
|
||||
"weight_scale",
|
||||
"weight_scale_2",
|
||||
))
|
||||
)
|
||||
if is_compressor_scale:
|
||||
# Buffer the shard for later concatenation
|
||||
if name not in compressor_scale_buffer:
|
||||
compressor_scale_buffer[name] = {}
|
||||
compressor_scale_buffer[name][shard_id] = loaded_weight
|
||||
loaded_params.add(name)
|
||||
break
|
||||
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
loaded_params.add(name)
|
||||
break
|
||||
@@ -1542,6 +1573,24 @@ class DeepseekV4Model(nn.Module):
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
|
||||
# Load buffered compressor/indexer scale params.
|
||||
# These are NVFP4 quantization scales that need concatenation
|
||||
# across shards (wkv=shard0, wgate=shard1) before loading.
|
||||
for name, shards in compressor_scale_buffer.items():
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
if len(shards) == 2:
|
||||
# Concatenate shard 0 and shard 1 along dim 0
|
||||
stacked = torch.cat([shards[0], shards[1]], dim=0)
|
||||
else:
|
||||
stacked = shards[0]
|
||||
assert param.data.shape == stacked.shape, (
|
||||
f"Scale shape mismatch for {name}: "
|
||||
f"param={param.data.shape} loaded={stacked.shape}"
|
||||
)
|
||||
param.data.copy_(stacked)
|
||||
|
||||
return loaded_params
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
Reference in New Issue
Block a user