[Core] Refactor ColumnParallelLinear: remove unused parameter and optimize forward (#31939)
Signed-off-by: maang <maang_h@163.com>
This commit is contained in:
@@ -411,10 +411,10 @@ class ReplicatedLinear(LinearBase):
         assert self.quant_method is not None
 
         output = self.quant_method.apply(self, x, bias)
-        output_bias = self.bias if self.skip_bias_add else None
 
         if not self.return_bias:
             return output
+        output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
     def extra_repr(self) -> str:
@@ -444,8 +444,6 @@ class ColumnParallelLinear(LinearBase):
             skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
-        output_sizes: list of output sizes packed into one output, like for QKV
-            the list would be size 3.
         prefix: The name of the layer in the state dict, including all parents
             (e.g. model.layers.0.qkv_proj)
         return_bias: If true, return bias together with outputs in forward pass.
@@ -463,7 +461,6 @@ class ColumnParallelLinear(LinearBase):
         skip_bias_add: bool = False,
         params_dtype: torch.dtype | None = None,
         quant_config: QuantizationConfig | None = None,
-        output_sizes: list[int] | None = None,
         prefix: str = "",
         *,
         return_bias: bool = True,
@@ -495,9 +492,6 @@ class ColumnParallelLinear(LinearBase):
         self._maybe_allow_fp8_block_shape_mismatch()
         self.gather_output = gather_output
 
-        if output_sizes is None:
-            output_sizes = [output_size]
-
         assert self.quant_method is not None
         self.quant_method.create_weights(
             layer=self,
@@ -614,9 +608,10 @@ class ColumnParallelLinear(LinearBase):
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
-        output_bias = self.bias if self.skip_bias_add else None
+
         if not self.return_bias:
             return output
+        output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
     def extra_repr(self) -> str:
@@ -1469,10 +1464,9 @@ class RowParallelLinear(LinearBase):
         else:
             output = output_parallel
 
-        output_bias = self.bias if self.skip_bias_add else None
-
         if not self.return_bias:
             return output
+        output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
     def extra_repr(self) -> str:
Reference in New Issue
Block a user