diff --git a/CMakeLists.txt b/CMakeLists.txt index ec67ee8c3..97e96e997 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -377,7 +377,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # preselected input type pairs and schedules. # Generate sources: set(MARLIN_GEN_SCRIPT - ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/marlin/generate_kernels.py) file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") @@ -412,7 +412,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if (MARLIN_ARCHS) - file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") + file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_float16.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_ARCHS}") @@ -422,7 +422,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) - file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") + file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_bfloat16.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_BF16_ARCHS}") @@ -434,7 +434,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if (MARLIN_SM75_ARCHS) - file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu") + file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/marlin/sm75_kernel_*.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_SM75_ARCHS}") @@ -446,7 +446,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if (MARLIN_FP8_ARCHS) - file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") + file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/marlin/sm89_kernel_*.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_FP8_ARCHS}") @@ -459,10 +459,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(MARLIN_SRCS "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/gptq_marlin/gptq_marlin.cu" - "csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu" - "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") + "csrc/quantization/marlin/marlin.cu" + "csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu" + "csrc/quantization/marlin/gptq_marlin_repack.cu" + "csrc/quantization/marlin/awq_marlin_repack.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_SRCS}" CUDA_ARCHS "${MARLIN_OTHER_ARCHS}") diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index ac78c019a..4e6f09866 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -231,7 +231,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: assert bt.w_tok_s is None assert bt.group_size is not None - fn = lambda: ops.gptq_marlin_gemm( + fn = lambda: ops.marlin_gemm( a=bt.a, c=None, b_q_weight=w_q, diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 48d790aec..0b79141d6 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -239,7 +239,7 @@ def bench_run( "sm_version": sm_version, "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD, # Kernels - "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "marlin_gemm": ops.marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_repack": ops.gptq_marlin_repack, "allspark_w8a16_gemm": ops.allspark_w8a16_gemm, @@ -263,21 +263,21 @@ def bench_run( results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="gptq_marlin_gemm", + description="marlin_gemm", ).blocked_autorange(min_run_time=min_run_time) ) results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="gptq_marlin_gemm_fp32", + description="marlin_gemm_fp32", ).blocked_autorange(min_run_time=min_run_time) ) diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index eb83df22c..e5a3a0b9c 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -3,8 +3,8 @@ #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif -#include "quantization/gptq_marlin/marlin.cuh" -#include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "quantization/marlin/marlin.cuh" +#include "quantization/marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" #define MARLIN_KERNEL_PARAMS \ diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index 5aac69b5c..cddc42643 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -23,10 +23,10 @@ #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif -#include "quantization/gptq_marlin/marlin.cuh" -#include "quantization/gptq_marlin/marlin_dtypes.cuh" -#include "quantization/gptq_marlin/dequant.h" -#include "quantization/gptq_marlin/marlin_mma.h" +#include "quantization/marlin/marlin.cuh" +#include "quantization/marlin/marlin_dtypes.cuh" +#include "quantization/marlin/dequant.h" +#include "quantization/marlin/marlin_mma.h" #include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/quantization/gptq_allspark/allspark_utils.cuh index 14a61ad8f..c7a6e96af 100644 --- a/csrc/quantization/gptq_allspark/allspark_utils.cuh +++ b/csrc/quantization/gptq_allspark/allspark_utils.cuh @@ -7,7 +7,7 @@ #include #include #include -#include "../gptq_marlin/marlin_dtypes.cuh" +#include "../marlin/marlin_dtypes.cuh" using marlin::MarlinScalarType2; namespace allspark { diff --git a/csrc/quantization/gptq_marlin/.gitignore b/csrc/quantization/marlin/.gitignore similarity index 100% rename from csrc/quantization/gptq_marlin/.gitignore rename to csrc/quantization/marlin/.gitignore diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/marlin/awq_marlin_repack.cu similarity index 100% rename from csrc/quantization/gptq_marlin/awq_marlin_repack.cu rename to csrc/quantization/marlin/awq_marlin_repack.cu diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/marlin/dequant.h similarity index 100% rename from csrc/quantization/gptq_marlin/dequant.h rename to csrc/quantization/marlin/dequant.h diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/marlin/generate_kernels.py similarity index 100% rename from csrc/quantization/gptq_marlin/generate_kernels.py rename to csrc/quantization/marlin/generate_kernels.py diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/marlin/gptq_marlin_repack.cu similarity index 100% rename from csrc/quantization/gptq_marlin/gptq_marlin_repack.cu rename to csrc/quantization/marlin/gptq_marlin_repack.cu diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/marlin/kernel.h similarity index 100% rename from csrc/quantization/gptq_marlin/kernel.h rename to csrc/quantization/marlin/kernel.h diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/marlin/marlin.cu similarity index 99% rename from csrc/quantization/gptq_marlin/gptq_marlin.cu rename to csrc/quantization/marlin/marlin.cu index 77f319d53..62826128c 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/marlin/marlin.cu @@ -46,7 +46,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, } // namespace marlin -torch::Tensor gptq_marlin_gemm( +torch::Tensor marlin_gemm( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, @@ -528,7 +528,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } // namespace marlin -torch::Tensor gptq_marlin_gemm( +torch::Tensor marlin_gemm( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, @@ -856,5 +856,5 @@ torch::Tensor gptq_marlin_gemm( #endif TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("gptq_marlin_gemm", &gptq_marlin_gemm); + m.impl("marlin_gemm", &marlin_gemm); } diff --git a/csrc/quantization/gptq_marlin/marlin.cuh b/csrc/quantization/marlin/marlin.cuh similarity index 100% rename from csrc/quantization/gptq_marlin/marlin.cuh rename to csrc/quantization/marlin/marlin.cuh diff --git a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh b/csrc/quantization/marlin/marlin_dtypes.cuh similarity index 100% rename from csrc/quantization/gptq_marlin/marlin_dtypes.cuh rename to csrc/quantization/marlin/marlin_dtypes.cuh diff --git a/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu b/csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu similarity index 100% rename from csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu rename to csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu diff --git a/csrc/quantization/gptq_marlin/marlin_mma.h b/csrc/quantization/marlin/marlin_mma.h similarity index 100% rename from csrc/quantization/gptq_marlin/marlin_mma.h rename to csrc/quantization/marlin/marlin_mma.h diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/marlin/marlin_template.h similarity index 100% rename from csrc/quantization/gptq_marlin/marlin_template.h rename to csrc/quantization/marlin/marlin_template.h diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 28f789b65..68257bdda 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -303,9 +303,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); ops.impl("permute_cols", torch::kCUDA, &permute_cols); - // gptq_marlin Optimized Quantized GEMM for GPTQ. + // Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4). ops.def( - "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " + "marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " "Tensor? b_bias_or_none,Tensor b_scales, " "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, " "Tensor? " diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index d3826468c..6b3d14da2 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -59,7 +59,7 @@ if current_platform.is_rocm(): pytest.skip( "These tests require gptq_marlin_repack," "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm," - "or gptq_marlin_gemm which are not supported on ROCm.", + "or marlin_gemm which are not supported on ROCm.", allow_module_level=True, ) @@ -417,7 +417,7 @@ def marlin_generate_valid_test_cases(): ), marlin_generate_valid_test_cases(), ) -def test_gptq_marlin_gemm( +def test_marlin_gemm( a_type, b_type, c_type, @@ -511,7 +511,7 @@ def test_gptq_marlin_gemm( output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device) - output = ops.gptq_marlin_gemm( + output = ops.marlin_gemm( a_input, output, marlin_q_w, @@ -646,7 +646,7 @@ def test_marlin_gemm_subset_input(): marlin_zp = marlin_make_empty_g_idx(marlin_s.device) workspace = marlin_make_workspace_new(a_input.device) - output = ops.gptq_marlin_gemm( + output = ops.marlin_gemm( a_input, None, marlin_q_w, @@ -695,7 +695,7 @@ def test_marlin_gemm_with_bias(size_m): marlin_zp = marlin_make_empty_g_idx(marlin_s.device) workspace = marlin_make_workspace_new(a_input.device) - output = ops.gptq_marlin_gemm( + output = ops.marlin_gemm( a_input, None, marlin_q_w, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7cce82073..ff63aef38 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -591,8 +591,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::gptq_marlin_gemm") - def _gptq_marlin_gemm_fake( + @register_fake("_C::marlin_gemm") + def _marlin_gemm_fake( a: torch.Tensor, c: torch.Tensor | None, b_q_weight: torch.Tensor, @@ -1312,7 +1312,7 @@ def marlin_int4_fp8_preprocess( return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace) -def gptq_marlin_gemm( +def marlin_gemm( a: torch.Tensor, c: torch.Tensor | None, b_q_weight: torch.Tensor, @@ -1333,7 +1333,7 @@ def gptq_marlin_gemm( use_fp32_reduce: bool = False, is_zp_float: bool = False, ) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm( + return torch.ops._C.marlin_gemm( a, c, b_q_weight, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index d167452b1..7fa850c85 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -563,7 +563,7 @@ def apply_gptq_marlin_linear( reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) - output = ops.gptq_marlin_gemm( + output = ops.marlin_gemm( reshaped_x, None, weight, @@ -628,7 +628,7 @@ def apply_awq_marlin_linear( ) reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) - output = ops.gptq_marlin_gemm( + output = ops.marlin_gemm( reshaped_x, None, weight, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index db56b84a9..789ed5dba 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -121,7 +121,7 @@ def apply_fp4_marlin_linear( inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn) - output = ops.gptq_marlin_gemm( + output = ops.marlin_gemm( a=inputs, c=None, b_q_weight=weight, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 5be688265..b5a557ce9 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -66,7 +66,7 @@ def apply_fp8_marlin_linear( # inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn) raise RuntimeError("Marlin W8A8 is not supported.") - output = ops.gptq_marlin_gemm( + output = ops.marlin_gemm( a=inputs, c=None, b_q_weight=weight,