[BUG] fixed fp8 conflict with aqlm (#4307)
Fixes fp8 iterface which broke in AQLM merge.
This commit is contained in:
@@ -64,12 +64,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight = Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
|
||||
Reference in New Issue
Block a user