[BUG] fixed fp8 conflict with aqlm (#4307)

Fixes fp8 iterface which broke in AQLM merge.
This commit is contained in:
Robert Shaw
2024-04-23 21:26:33 -04:00
committed by GitHub
parent eace8bf0b9
commit 79a268c4ab
3 changed files with 18 additions and 4 deletions

View File

@@ -34,9 +34,19 @@ class LinearMethodBase(ABC):
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer."""
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list contains the width of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError
@abstractmethod