[Quantization] [Performance] Enable Marlin GEMM kernels for the calibration-free RTN-based quantization (#26051)

Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
Signed-off-by: Alex Kogan <82225080+sakogan@users.noreply.github.com>
This commit is contained in:
Alex Kogan
2025-10-13 14:52:54 -04:00
committed by GitHub
parent f89f599395
commit 89342ce4c0
2 changed files with 233 additions and 56 deletions

View File

@@ -528,3 +528,48 @@ def apply_awq_marlin_linear(
)
return output.reshape(out_shape)
def apply_rtn_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    quant_type: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    bias: torch.Tensor | None = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    """Run the Marlin GEMM kernel on RTN-quantized weights.

    Collapses all leading dimensions of ``input`` into a single batch
    dimension, dispatches to ``ops.gptq_marlin_gemm``, and restores the
    original leading dimensions on the result.
    """
    # Flatten (..., K) -> (M, K) so the kernel sees a plain 2-D problem.
    x_2d = input.reshape(-1, input.shape[-1])
    final_shape = input.shape[:-1] + (output_size_per_partition,)

    # Let the project heuristic decide whether the atomic-add reduction
    # path is preferable for this (m, n, k)/device/dtype combination.
    atomic_add_ok = should_use_atomic_add_reduce(
        m=x_2d.size(0),
        n=output_size_per_partition,
        k=x_2d.size(1),
        device=input.device,
        dtype=input.dtype,
    )

    # NOTE(review): the positional ``None`` arguments presumably stand in
    # for GPTQ-specific inputs (act-order perm / zero points / sorting
    # metadata) that RTN quantization does not produce — confirm against
    # the ``gptq_marlin_gemm`` signature.
    gemm_out = ops.gptq_marlin_gemm(
        x_2d,
        None,
        weight,
        bias,
        weight_scale,
        None,
        None,
        None,
        None,
        workspace,
        quant_type,
        size_m=x_2d.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        use_atomic_add=atomic_add_ok,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    # Restore the caller's leading dimensions: (..., N).
    return gemm_out.reshape(final_shape)