[Bugfix][CI] Fix Marlin FP8 Linear Kernel for Compressed Tensors Format (#38092)

Signed-off-by: BadrBasowid <Badr.Basowid@gmail.com>
Signed-off-by: BadrBasowid <61441185+BadrBasowid@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
BadrBasowid
2026-03-26 12:11:43 +08:00
committed by GitHub
parent 144030c84e
commit e6bf9f15ec
3 changed files with 25 additions and 2 deletions

View File

@@ -188,6 +188,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
if self.strategy == QuantizationStrategy.BLOCK:
maybe_post_process_fp8_weight_block(layer)
if hasattr(self, "fp8_linear"):
self.fp8_linear.process_weights_after_loading(layer)
def apply_weights(
self,
layer: torch.nn.Module,

View File

@@ -705,6 +705,9 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)
if hasattr(self, "fp8_linear"):
self.fp8_linear.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,