[Bugfix] Handle process_weights_after_loading for QKVCrossParallelLinear (#15328)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author:  Isotr0py
Date:    2025-04-09 01:02:23 +08:00 (committed by GitHub)
Parent:  4ebc0b9640
Commit:  40b4284fe3
3 changed files with 33 additions and 6 deletions


@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
                                                  prefix=f"{prefix}.kv_proj_encoder")
 
+        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
         self.q_size = self.q_proj_decoder.output_size_per_partition
         self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
 
         if bias:
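
Note on the new comment in this hunk: under tensor parallelism the `QKVParallelLinear` encoder projection already holds a per-rank head count, so `num_kv_heads * head_size` is the per-partition KV width, not the global one. A minimal sketch of that arithmetic, with a hypothetical helper name and the assumption that KV heads are split across ranks (and replicated when tp exceeds the head count):

```python
# Hypothetical helper, not vLLM API: per-rank KV width under tensor parallelism.
def per_rank_kv_size(total_kv_heads: int, tp_size: int, head_size: int) -> int:
    # Assumed sharding rule: divide KV heads across ranks; if there are fewer
    # heads than ranks, each rank keeps (a replica of) at least one head.
    per_rank_heads = max(1, total_kv_heads // tp_size)
    return per_rank_heads * head_size

# 8 KV heads of size 64 split over tp_size=4 -> 2 heads, 128 values per rank.
assert per_rank_kv_size(total_kv_heads=8, tp_size=4, head_size=64) == 128
```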
@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
         else:
             self.bias = None
 
+    def process_weights_after_loading(self):
+        for layer in self.proj.values():
+            if self.quant_method is not None:
+                self.quant_method.process_weights_after_loading(layer)
+
     @property
     def q_proj_decoder(self) -> ColumnParallelLinear:
         layer = self.proj["q_proj_decoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="q_proj_decoder")
         return layer
 
     @property
     def kv_proj_encoder(self) -> QKVParallelLinear:
         layer = self.proj["kv_proj_encoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="kv_proj_encoder")
         return layer
 
     def sync_weight_attrs(
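
Two things happen in the hunk above. First, the new `process_weights_after_loading` override forwards the post-load hook to each wrapped projection, so their quant methods get to repack weights just as they would for a standalone linear layer. Second, the properties now tolerate a missing attribute: a quant method's post-load hook may replace or rename a sub-layer's parameters, after which a name taken from the wrapper no longer exists on the sub-layer. A standalone illustration of that failure mode, using plain torch and a made-up repacking step rather than vLLM code:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4, bias=False)

# Simulate what a quantization post-load hook might do: drop the original
# floating-point parameter and register a packed buffer under another name.
del layer.weight
layer.register_buffer("qweight", torch.zeros(4, 2, dtype=torch.int8))

# Before this fix: getattr(layer, "weight") raises AttributeError here.
# After: the missing name is skipped and attribute syncing is a no-op.
target_param = getattr(layer, "weight", None)
if target_param is not None:
    print("sync weight attrs")  # not reached in this example
```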
@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
                  if loaded_shard_id == "q" else self.kv_proj_encoder)
         target_param = self.select_proj_params(layer, param)
         shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
+        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
+            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
+        else:
+            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
 
     def extra_repr(self) -> str:
         s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", q_size={self.q_size}"
         s += f", kv_size={self.kv_size}"
         s += f", bias={self.bias is not None}"
         s += f", tp_size={get_tensor_model_parallel_world_size()}"