[Quantization] [Performance] Enable Marlin GEMM kernels for the calibration-free RTN-based quantization (#26051)
Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
Signed-off-by: Alex Kogan <82225080+sakogan@users.noreply.github.com>
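For context: RTN (round-to-nearest) quantization rounds each weight directly to the nearest representable value using per-group scales, with no calibration pass over sample data. The following is a minimal plain-PyTorch sketch of that idea, not vLLM's implementation; the group size of 128 and the symmetric int4 range [-8, 7] are illustrative assumptions.

import torch

def rtn_quantize(w: torch.Tensor, group_size: int = 128):
    """Round-to-nearest int4 weight quantization with per-group scales (illustrative)."""
    out_features, in_features = w.shape
    # Split each output row into groups of `group_size` input channels.
    grouped = w.reshape(out_features, in_features // group_size, group_size)
    # Symmetric scale: map each group's max magnitude onto the int4 extreme 7.
    scales = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    # Round to nearest -- no calibration data involved.
    q = torch.clamp(torch.round(grouped / scales), -8, 7)
    return q.reshape(out_features, in_features), scales.squeeze(-1)

def rtn_dequantize(q: torch.Tensor, scales: torch.Tensor, group_size: int = 128):
    out_features, in_features = q.shape
    grouped = q.reshape(out_features, in_features // group_size, group_size)
    return (grouped * scales.unsqueeze(-1)).reshape(out_features, in_features)

w = torch.randn(256, 512)
q, s = rtn_quantize(w)
err = (w - rtn_dequantize(q, s)).abs().max()
print(f"max abs quantization error: {err:.4f}")

In the diff below, the rounding and packing have already happened elsewhere; the new helper only dispatches the resulting quantized weight to the Marlin GEMM kernel.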
@@ -528,3 +528,48 @@ def apply_awq_marlin_linear(
     )
 
     return output.reshape(out_shape)
+
+
+def apply_rtn_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    quant_type: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    bias: torch.Tensor | None = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
+
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        bias,
+        weight_scale,
+        None,
+        None,
+        None,
+        None,
+        workspace,
+        quant_type,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
+
+    return output.reshape(out_shape)
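The new helper mirrors apply_awq_marlin_linear directly above it: activations are flattened to 2D for the kernel and reshaped back afterwards, and should_use_atomic_add_reduce picks the reduction strategy from the GEMM shape, device, and dtype. Below is a minimal sketch of just that shape handling, with a plain matmul on a dequantized weight standing in for ops.gptq_marlin_gemm (whose Marlin-packed weight and workspace tensors can't be constructed in a few lines); apply_linear_like and its argument names are hypothetical.

import torch

def apply_linear_like(
    input: torch.Tensor,
    weight_deq: torch.Tensor,          # dequantized (out_features, in_features)
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Collapse all leading dims, e.g. (batch, seq, hidden) -> (batch*seq, hidden).
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (weight_deq.shape[0],)
    # The real code calls ops.gptq_marlin_gemm here with the packed weight,
    # scales, workspace, and the atomic-add / fp32-reduce flags.
    output = reshaped_x @ weight_deq.t()
    if bias is not None:
        output = output + bias
    # Restore the original leading dimensions of the activation tensor.
    return output.reshape(out_shape)

x = torch.randn(2, 16, 512)            # (batch, seq_len, hidden)
w = torch.randn(256, 512)              # (out_features, in_features)
print(apply_linear_like(x, w).shape)   # torch.Size([2, 16, 256])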