[Core] Refactor ColumnParallelLinear: remove unused output_sizes parameter and optimize forward (#31939)

Signed-off-by: maang <maang_h@163.com>
Author: maang
Date: 2026-01-10 12:19:49 +08:00
Committed by: GitHub
Parent: c60578de0a
Commit: 52d428295d
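The forward() change is the same three-line move in ReplicatedLinear, ColumnParallelLinear, and RowParallelLinear: the output_bias lookup now happens after the return_bias early exit, so callers that only want the output tensor never pay for it. A minimal runnable sketch of the pattern, using a hypothetical TinyLinear class (not the vLLM layer; it only mimics the skip_bias_add / return_bias return contract):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLinear(nn.Module):
    """Hypothetical stand-in for the layers touched in this diff."""

    def __init__(self, in_features: int, out_features: int, *,
                 skip_bias_add: bool = False, return_bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.skip_bias_add = skip_bias_add
        self.return_bias = return_bias

    def forward(self, x: torch.Tensor):
        # skip_bias_add defers the bias: it is handed back to the caller
        # instead of being fused into the matmul here.
        fused_bias = None if self.skip_bias_add else self.bias
        output = F.linear(x, self.weight, fused_bias)
        # Post-commit ordering: bail out before computing output_bias when
        # the caller asked for the bare tensor.
        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias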


@@ -411,10 +411,10 @@ class ReplicatedLinear(LinearBase):
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
-        output_bias = self.bias if self.skip_bias_add else None
         if not self.return_bias:
             return output
+        output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias

     def extra_repr(self) -> str:
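Caller-side view of the two return conventions, using the hypothetical TinyLinear sketch above:

x = torch.randn(4, 16)

layer = TinyLinear(16, 32, skip_bias_add=True)  # bias deferred to the caller
out, out_bias = layer(x)                        # out_bias is layer.bias

plain = TinyLinear(16, 32, return_bias=False)   # bare-tensor convention
out = plain(x)                                  # no tuple unpacking needed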
@@ -444,8 +444,6 @@ class ColumnParallelLinear(LinearBase):
             skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization config.
-        output_sizes: list of output sizes packed into one output, like for QKV
-            the list would be size 3.
         prefix: The name of the layer in the state dict, including all parents
             (e.g. model.layers.0.qkv_proj)
         return_bias: If true, return bias together with outputs in forward pass.
@@ -463,7 +461,6 @@ class ColumnParallelLinear(LinearBase):
         skip_bias_add: bool = False,
         params_dtype: torch.dtype | None = None,
         quant_config: QuantizationConfig | None = None,
-        output_sizes: list[int] | None = None,
         prefix: str = "",
         *,
         return_bias: bool = True,
@@ -495,9 +492,6 @@ class ColumnParallelLinear(LinearBase):
         self._maybe_allow_fp8_block_shape_mismatch()
         self.gather_output = gather_output
-        if output_sizes is None:
-            output_sizes = [output_size]
         assert self.quant_method is not None
         self.quant_method.create_weights(
             layer=self,
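Why dropping the kwarg is safe: ColumnParallelLinear accepted output_sizes, defaulted it to [output_size], and never read it afterwards; the packed variants (e.g. MergedColumnParallelLinear, QKVParallelLinear) track their per-partition sizes themselves and hand the base class only the sum. A hedged sketch of that division of labor, extending the hypothetical TinyLinear above (the real vLLM subclasses differ in detail):

class PackedTinyLinear(TinyLinear):
    """Hypothetical analogue of a packed layer such as MergedColumnParallelLinear."""

    def __init__(self, in_features: int, output_sizes: list[int], **kwargs):
        # The packed variant keeps the per-partition sizes for weight
        # loading; the base class only ever needs their sum.
        super().__init__(in_features, sum(output_sizes), **kwargs)
        self.output_sizes = output_sizes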
@@ -614,9 +608,10 @@ class ColumnParallelLinear(LinearBase):
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
-        output_bias = self.bias if self.skip_bias_add else None
         if not self.return_bias:
             return output
+        output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias

     def extra_repr(self) -> str:
@@ -1469,10 +1464,9 @@ class RowParallelLinear(LinearBase):
         else:
             output = output_parallel
-        output_bias = self.bias if self.skip_bias_add else None
         if not self.return_bias:
             return output
+        output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias

     def extra_repr(self) -> str:
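A quick equivalence check one could run against the toy model above (a hypothetical test, not part of this commit): the reordering only delays the output_bias lookup, so every flag combination must keep the same return shape.

import itertools

for skip, ret in itertools.product([False, True], repeat=2):
    layer = TinyLinear(8, 8, skip_bias_add=skip, return_bias=ret)
    result = layer(torch.randn(2, 8))
    if ret:
        out, out_bias = result
        assert (out_bias is not None) == skip  # bias returned only when deferred
    else:
        assert isinstance(result, torch.Tensor)  # bare tensor, no tuple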