[Bugfix][CI] Fix Marlin FP8 Linear Kernel for Compressed Tensors Format (#38092)
Signed-off-by: BadrBasowid <Badr.Basowid@gmail.com>
Signed-off-by: BadrBasowid <61441185+BadrBasowid@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -76,8 +76,25 @@ class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
||||
else:
|
||||
weight = layer.weight.t()
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
w_q, *_ = self._get_layer_params(layer)
|
||||
# Compressed tensors transposes the weight to (K, N)
|
||||
# for channel and tensor quant strategies.
|
||||
# So we can skip the transpose if the layout is
|
||||
# already (K, N).
|
||||
# TODO: Remove this check once the layouts have been
|
||||
# canonicalized to a standard (N, K) dimension. See issue
|
||||
# #33314 for more details.
|
||||
if w_q.shape != (
|
||||
layer.input_size_per_partition,
|
||||
layer.output_size_per_partition,
|
||||
):
|
||||
# transpose the weights to (K,N)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"weight",
|
||||
w_q.t(),
|
||||
)
|
||||
|
||||
layer.input_scale = None
|
||||
prepare_fp8_layer_for_marlin(
|
||||
layer, self.size_k_first, input_dtype=self.marlin_input_dtype
|
||||
|
||||
@@ -188,6 +188,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
|
||||
if hasattr(self, "fp8_linear"):
|
||||
self.fp8_linear.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@@ -705,6 +705,9 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
||||
|
||||
layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)
|
||||
|
||||
if hasattr(self, "fp8_linear"):
|
||||
self.fp8_linear.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
Reference in New Issue
Block a user