[Refactor] Rename gptq_marlin to marlin to match MoE (#32952)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -591,8 +591,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     ) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @register_fake("_C::gptq_marlin_gemm")
-    def _gptq_marlin_gemm_fake(
+    @register_fake("_C::marlin_gemm")
+    def _marlin_gemm_fake(
         a: torch.Tensor,
         c: torch.Tensor | None,
         b_q_weight: torch.Tensor,
@@ -1312,7 +1312,7 @@ def marlin_int4_fp8_preprocess(
     return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)
 
 
-def gptq_marlin_gemm(
+def marlin_gemm(
     a: torch.Tensor,
     c: torch.Tensor | None,
     b_q_weight: torch.Tensor,
@@ -1333,7 +1333,7 @@ def gptq_marlin_gemm(
     use_fp32_reduce: bool = False,
     is_zp_float: bool = False,
 ) -> torch.Tensor:
-    return torch.ops._C.gptq_marlin_gemm(
+    return torch.ops._C.marlin_gemm(
         a,
         c,
         b_q_weight,

@@ -563,7 +563,7 @@ def apply_gptq_marlin_linear(
 
     reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
 
-    output = ops.gptq_marlin_gemm(
+    output = ops.marlin_gemm(
         reshaped_x,
         None,
         weight,
@@ -628,7 +628,7 @@ def apply_awq_marlin_linear(
     )
     reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
 
-    output = ops.gptq_marlin_gemm(
+    output = ops.marlin_gemm(
         reshaped_x,
         None,
         weight,

@@ -121,7 +121,7 @@ def apply_fp4_marlin_linear(
 
     inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
 
-    output = ops.gptq_marlin_gemm(
+    output = ops.marlin_gemm(
         a=inputs,
         c=None,
         b_q_weight=weight,

@@ -66,7 +66,7 @@ def apply_fp8_marlin_linear(
     # inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
     raise RuntimeError("Marlin W8A8 is not supported.")
 
-    output = ops.gptq_marlin_gemm(
+    output = ops.marlin_gemm(
         a=inputs,
         c=None,
         b_q_weight=weight,
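For downstream callers, the rename is purely mechanical: every ops.gptq_marlin_gemm(...) call site becomes ops.marlin_gemm(...) with identical arguments, as the hunks above show. A minimal compatibility sketch for out-of-tree code that still imports the old name (not part of this commit; it assumes vllm._custom_ops exposes marlin_gemm after the rename):

    from vllm import _custom_ops as ops

    # Hypothetical shim: keep the old name alive until callers migrate.
    # The renamed wrapper keeps the same signature, so a plain alias suffices.
    gptq_marlin_gemm = ops.marlin_gemm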