[BUG] fixed fp8 conflict with aqlm (#4307)

Fixes the fp8 interface, which broke in the AQLM merge.
This commit is contained in:
Robert Shaw
2024-04-23 21:26:33 -04:00
committed by GitHub
parent eace8bf0b9
commit 79a268c4ab
3 changed files with 18 additions and 4 deletions

View File

@@ -64,12 +64,13 @@ class Fp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),