Llamas 3.1 405B fp4 changes upstreaming from 355_wip (#25135)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com>
This commit is contained in:
committed by
GitHub
parent
8b77328ffe
commit
53a30845be
@@ -323,6 +323,12 @@ class ReplicatedLinear(LinearBase):
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
# If MergedReplicatedLinear, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = self.output_sizes
|
||||
else:
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
super().__init__(input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
@@ -335,7 +341,8 @@ class ReplicatedLinear(LinearBase):
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size, [self.output_size],
|
||||
self.input_size,
|
||||
self.output_partition_sizes,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
@@ -374,12 +381,15 @@ class ReplicatedLinear(LinearBase):
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
assert self.quant_method is not None
|
||||
|
||||
output = self.quant_method.apply(self, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
@@ -413,7 +423,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
output_sizes: list of output sizes packed into one output, like for QKV
|
||||
the list would be size 3.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: If true, weights matrix won't be sharded through tp rank.
|
||||
"""
|
||||
@@ -535,13 +545,15 @@ class ColumnParallelLinear(LinearBase):
|
||||
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
||||
|
||||
def forward(
|
||||
self, input_
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
|
||||
if self.gather_output and self.tp_size > 1:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
@@ -1326,7 +1338,8 @@ class RowParallelLinear(LinearBase):
|
||||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
||||
|
||||
def forward(
|
||||
self, input_
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
@@ -1340,9 +1353,8 @@ class RowParallelLinear(LinearBase):
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output_parallel = self.quant_method.apply(self, input_parallel, bias_)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user