diff --git a/CMakeLists.txt b/CMakeLists.txt index 97e96e997..0000b6d32 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -458,7 +458,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() set(MARLIN_SRCS - "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/marlin.cu" "csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu" "csrc/quantization/marlin/gptq_marlin_repack.cu" diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 0b79141d6..c0019a51c 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -6,12 +6,6 @@ import torch.utils.benchmark as benchmark from benchmark_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, - GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, - GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, -) from vllm.model_executor.layers.quantization.utils.allspark_utils import ( ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES, @@ -34,9 +28,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( awq_marlin_quantize, marlin_quantize, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( - marlin_24_quantize, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, gptq_quantize_weights, @@ -78,14 +69,7 @@ def bench_run( if size_k % group_size != 0: return - marlin_24_supported = ( - quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES - and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES - ) - repack_supported = ( - quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES - and group_size in MARLIN_SUPPORTED_GROUP_SIZES - ) + repack_supported = group_size in MARLIN_SUPPORTED_GROUP_SIZES allspark_supported = ( quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES and group_size == -1 @@ -126,14 +110,6 @@ def bench_run( marlin_sort_indices, ) - def gen_marlin_24_params(): - marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None - if marlin_24_supported: - (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( - marlin_24_quantize(b, quant_type, group_size) - ) - return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) - def gen_repack_params(): q_w_gptq = None repack_sort_indices = None @@ -188,9 +164,6 @@ def bench_run( marlin_g_idx, marlin_sort_indices, ) = gen_marlin_params() - marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = ( - gen_marlin_24_params() - ) q_w_gptq, repack_sort_indices = gen_repack_params() qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = ( gen_allspark_params() @@ -200,9 +173,6 @@ def bench_run( marlin_workspace = MarlinWorkspace( size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL ) - marlin_24_workspace = MarlinWorkspace( - size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL - ) globals = { # Gen params @@ -222,12 +192,6 @@ def bench_run( "marlin_sort_indices": marlin_sort_indices, "marlin_workspace": marlin_workspace, "is_k_full": is_k_full, - # Marlin_24 params - "marlin_24_w_ref": marlin_24_w_ref, - "marlin_24_q_w_comp": marlin_24_q_w_comp, - "marlin_24_meta": marlin_24_meta, - "marlin_24_s": marlin_24_s, - "marlin_24_workspace": marlin_24_workspace, # GPTQ params "q_w_gptq": q_w_gptq, "repack_sort_indices": repack_sort_indices, @@ -240,7 +204,6 @@ def bench_run( "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD, # 
Kernels "marlin_gemm": ops.marlin_gemm, - "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_repack": ops.gptq_marlin_repack, "allspark_w8a16_gemm": ops.allspark_w8a16_gemm, } @@ -281,17 +244,6 @@ def bench_run( ).blocked_autorange(min_run_time=min_run_time) ) - if marlin_24_supported: - results.append( - benchmark.Timer( - stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="gptq_marlin_24_gemm", - ).blocked_autorange(min_run_time=min_run_time) - ) - if repack_supported: results.append( benchmark.Timer( diff --git a/csrc/quantization/marlin/sparse/LICENSE b/csrc/quantization/marlin/sparse/LICENSE deleted file mode 100644 index ca75fb15e..000000000 --- a/csrc/quantization/marlin/sparse/LICENSE +++ /dev/null @@ -1,203 +0,0 @@ -Contains code from https://github.com/IST-DASLab/Sparse-Marlin/ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/csrc/quantization/marlin/sparse/common/base.h b/csrc/quantization/marlin/sparse/common/base.h deleted file mode 100644 index 16018d331..000000000 --- a/csrc/quantization/marlin/sparse/common/base.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -namespace marlin_24 { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -template -struct ShapeBase { - static constexpr int M = M_, N = N_, K = K_; -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragM = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mem.h b/csrc/quantization/marlin/sparse/common/mem.h deleted file mode 100644 index 83e3578d2..000000000 --- a/csrc/quantization/marlin/sparse/common/mem.h +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "base.h" - -namespace marlin_24 { -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred_zfill(void* smem_ptr, - const void* glob_ptr, - bool pred = true, - const bool zfill = false) { - const int BYTES = 16; - int src_in_bytes = (zfill ? 0 : BYTES); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); -} - -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_m); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. 
-__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} -} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h deleted file mode 100644 index b26505f77..000000000 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "base.h" -#include - -namespace marlin_24 { - -// On CUDA earlier than 12.5, the ordered_metadata version of this instruction -// is not supported. On later versions of CUDA the version without ordered -// metadata results in the following warning: -// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction -// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially -// | reduced performance on some future architectures -#if defined CUDA_VERSION && CUDA_VERSION >= 12050 - #define MMA_SP_INST \ - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " -#else - #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " -#endif - -// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. 
-__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, - const FragA& frag_b, FragC& frag_c, FragM& frag_m, - const int psel) { - const uint32_t* a0 = reinterpret_cast(&a_frag0); - const uint32_t* a1 = reinterpret_cast(&a_frag1); - const uint32_t* b = reinterpret_cast(&frag_b); - const uint32_t* e = reinterpret_cast(&frag_m); - - float* c = reinterpret_cast(&frag_c); - if (psel == 0) { - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); - } else { - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); - } -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, - float c3) { - uint2 r; - asm("{\n\t" - ".reg .f16 a, b, c, d; \n\t" - "cvt.rn.f16.f32 a, %2; \n\t" - "cvt.rn.f16.f32 b, %3; \n\t" - "cvt.rn.f16.f32 c, %4; \n\t" - "cvt.rn.f16.f32 d, %5; \n\t" - "mov.b32 %0, {a, b}; \n\t" - "mov.b32 %1, {c, d}; \n\t" - "}" - : "=r"(r.x), "=r"(r.y) - : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); - return r; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_4bit(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. 
- const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_8bit(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, - FragS& s0, float* c4, float* c5, float* c6, - float* c7, FragS& s1) { - *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); - *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); - *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); - *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); - - *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); - *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); - *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); - *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); -} - -} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu deleted file mode 100644 index c33e71ae5..000000000 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ /dev/null @@ -1,1145 +0,0 @@ -/* - * Notice: This file was modified by Neuralmagic inc to include 8-bit support - * - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" -#include "core/scalar_type.hpp" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -#else - - #include "common/mem.h" - #include "common/mma.h" - -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_24 { - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -static constexpr int THREADS = 256; -static constexpr int STAGES = 4; - -static constexpr int min_thread_n = 128; - -static constexpr int tile_size = 16; -static constexpr int max_par = 64; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, - vllm::ScalarTypeId const b_q_type_id, - int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. 
- - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - // number of thread_k_blocks in k-dim - int k_tiles = prob_k / 32 / thread_k_blocks; - // number of thread_n_blocks in n-dim - int n_tiles = prob_n / 16 / thread_n_blocks; - // iters needed to cover all slices - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - // number of threadblock tiles in the current slice - int slice_iters; - // total number of active threadblocks in the current slice - int slice_count = 0; - // index of threadblock in current slice; numbered bottom to top - int slice_idx; - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 32 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads //RLC: 2 * #warps k-dim - constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall 
size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - constexpr int pack_factor = 32 / num_bits; - - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 - constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp - int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; - int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); - constexpr int m_sh_wr_delta = threads / 2; - constexpr int m_sh_rd_delta = threads / 2; - constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; - constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; - auto b_sh_rd = threadIdx.x * b_thread_vecs; - - int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + - (threadIdx.x % (m_sh_stride)); - m_gl_rd += (m_sh_stride)*slice_col; - m_gl_rd += m_gl_rd_delta_o * slice_row; - auto m_sh_wr = threadIdx.x; - auto m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; - - int s_gl_rd; - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - - auto s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; // Note that in the original Marlin kernel - // this is (threadIdx.x % 32) / 4 - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. 
- bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) { - a_sh_rd_trans[0][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - a_sh_rd_trans[1][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); - } - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4* meta_ptr[m_sh_iters]; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - int4* sh_m = sh_s + (stages * s_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks][2]; - I4 frag_b_quant[2][b_thread_vecs]; - FragM frag_m[2][2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. 
- auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - int4* sh_meta_stage = sh_m + m_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) { - if (m_sh_wr_pred) - cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); - meta_ptr[i] += m_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - ldsm4(frag_a[k % 2][i][0], - &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); - ldsm4(frag_a[k % 2][i][1], - &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); - } - - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - - // Load meta with ldsm4 - int4* sh_m_stage = sh_m + m_sh_stage * pipe; - ldsm4_m(frag_m[k % 2][0], - &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. 
- auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - - if constexpr (num_bits == 4) { - int b_quant = frag_b_quant[k % 2][0][j]; - int b_quant_shift = b_quant >> 8; - - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); - - } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - } - - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], - frag_m[k % 2][j / 2], j % 2); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. 
To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; - int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) - int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; - - int col = 2 * ((threadIdx.x % 32) % 4); - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2] += - __half2float( - reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); - } - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2]); - } - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: - - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - - int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) - - constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: - int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + - (threadIdx.x % (2 * 2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, - float c4, float c5, float c6, float c7, FragS& s1) { - uint2 res[2]; - res[0] = to_half4(c0, c1, c2, c3); - res[1] = to_half4(c4, c5, c6, c7); - half2* tmp = (half2*)&res; - // for per-column quantization we finally apply the scale here - if constexpr (group_blocks == -1 && num_bits == 4) { - tmp[0] = __hmul2(tmp[0], s0[0]); - tmp[1] = __hmul2(tmp[1], s0[1]); - tmp[2] = __hmul2(tmp[2], s1[0]); - tmp[3] = __hmul2(tmp[3], s1[1]); - } - ((int4*)sh)[idx] = *((int4*)&res[0]); - }; - - // RLC: only warp 0 and 1 baseline example - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - int wr = c_sh_wr; - write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], - frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], - frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], - frag_s[0][2]); - write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], - frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], - frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], - frag_c[i][3][0][3], frag_s[0][2]); - write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], - frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], - frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], - frag_c[i][3][1][2], frag_s[0][2]); - write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], - frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], - frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], - frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); - - c_sh_wr += 8 * c_sh_stride_2; - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. 
- #pragma unroll - for (int pipe = 0; pipe < stages;) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - matmul(pipe); - wait_for_stage(); - - fetch_to_registers(pipe + 1, (pipe + 1) % stages); - - pipe++; - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (group_blocks == -1) { - if constexpr (num_bits == 8) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } else { - if (last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - } - } - thread_block_reduce(); - - if constexpr (group_blocks == -1) { - if constexpr (num_bits == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); - } - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (group_blocks == -1 && num_bits == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], - &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], - &frag_c[i][0][0][2], &frag_c[i][1][0][2], - &frag_c[i][2][0][2], &frag_c[i][3][0][2], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], - &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], - &frag_c[i][0][0][3], &frag_c[i][1][0][3], - &frag_c[i][2][0][3], &frag_c[i][3][0][3], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], - &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], - &frag_c[i][0][1][2], &frag_c[i][1][1][2], - &frag_c[i][2][1][2], &frag_c[i][3][1][2], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], - &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], - &frag_c[i][0][1][3], &frag_c[i][1][1][3], - &frag_c[i][2][1][3], &frag_c[i][3][1][3], - frag_s[0][2]); - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - #pragma unroll - for (int i = 0; i < 
m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
- }
- s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
- start_pipes();
- }
- }
- }
-}
-
-#endif
-
-#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
- THREAD_K_BLOCKS, GROUP_BLOCKS) \
- else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
- thread_n_blocks == THREAD_N_BLOCKS && \
- thread_k_blocks == THREAD_K_BLOCKS && \
- group_blocks == GROUP_BLOCKS) { \
- cudaFuncSetAttribute( \
- Marlin_24, \
- cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
- Marlin_24 \
- <<>>(A_ptr, B_ptr, meta_ptr, \
- C_ptr, s_ptr, prob_n, \
- prob_m, prob_k, locks); \
- }
-
-void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
- void* s, int prob_m, int prob_n, int prob_k,
- void* workspace, int num_bits, int groupsize = -1,
- int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
- int thread_m = -1, int sms = -1, int max_par = 16) {
- int tot_n = prob_n;
- int tot_n_blocks = ceildiv(tot_n, 16);
- int pad = 16 * tot_n_blocks - tot_n;
-
- if (sms == -1) {
- cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
- }
- TORCH_CHECK(sms > 0);
-
- int max_shared_mem = 0;
- cudaDeviceGetAttribute(&max_shared_mem,
- cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
- TORCH_CHECK(max_shared_mem > 0);
-
- if (thread_k == -1 || thread_m == -1) {
- if (prob_n <= 16) {
- // For small batch sizes, better partitioning is slightly more important
- // than better compute utilization
- thread_k = 128;
- thread_m = 128;
- } else {
- thread_k = 64;
- thread_m = 256;
- }
- // Also had
- // if prob_n > 256
- // thread_k = 32;
- // thread_m = 512;
- // but this is broken,
- // TODO(Lucas, Alex M): figure out why
- }
-
- int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
- int thread_m_blocks = thread_m / 16;
- int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
- int blocks = sms;
-
- TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m,
- " is not divisible by thread_m = ", thread_m);
- TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
- " is not divisible by thread_k = ", thread_k);
- if (group_blocks != -1) {
- TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2,
- " is not divisible by group_blocks = ", group_blocks);
- }
-
- TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
- ", ", prob_n, ", ", prob_k, "]");
-
- const int4* A_ptr = (const int4*)A;
- const int4* B_ptr = (const int4*)B;
- const int4* meta_ptr = (const int4*)meta;
- int4* C_ptr = (int4*)C;
- const int4* s_ptr = (const int4*)s;
-
- constexpr int max_m_blocks = 4;
-
- int* locks = (int*)workspace;
- for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
- int thread_n_blocks = tot_n_blocks - i;
- prob_n = tot_n - 16 * i;
- int par = 1;
- if (thread_n_blocks > max_m_blocks) {
- // Note that parallel > 1 currently only works for inputs without any
- // padding
- par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
- if (par > max_par) par = max_par;
- prob_n = (max_m_blocks * 16) * par;
- i += max_m_blocks * (par - 1);
- thread_n_blocks = max_m_blocks;
- }
-
- // For compilation speed, we only define the kernel configurations that have
- // seemed useful (in terms of performance) in our testing, however many more
- // are, in principle, possible.
- - // the false is start of the CALL_IF macros - if (false) { - } // BMxBNxBK, group - // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 - - CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 - CALL_IF_2_4(4, 16, 2, 2, 4) - CALL_IF_2_4(4, 16, 3, 2, -1) - CALL_IF_2_4(4, 16, 3, 2, 4) - CALL_IF_2_4(4, 16, 4, 2, -1) - CALL_IF_2_4(4, 16, 4, 2, 4) - - CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 - CALL_IF_2_4(4, 32, 2, 1, 4) - CALL_IF_2_4(4, 32, 3, 1, -1) - CALL_IF_2_4(4, 32, 3, 1, 4) - CALL_IF_2_4(4, 32, 4, 1, -1) - CALL_IF_2_4(4, 32, 4, 1, 4) - - // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 - - CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 - CALL_IF_2_4(8, 16, 2, 2, 4) - CALL_IF_2_4(8, 16, 3, 2, -1) - CALL_IF_2_4(8, 16, 3, 2, 4) - CALL_IF_2_4(8, 16, 4, 2, -1) - CALL_IF_2_4(8, 16, 4, 2, 4) - - CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 - CALL_IF_2_4(8, 32, 2, 1, 4) - CALL_IF_2_4(8, 32, 3, 1, -1) - CALL_IF_2_4(8, 32, 3, 1, 4) - CALL_IF_2_4(8, 32, 4, 1, -1) - CALL_IF_2_4(8, 32, 4, 1, 4) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; - } -} - -} // namespace marlin_24 - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, - vllm::ScalarTypeId const b_q_type_id, - int64_t size_m, int64_t size_n, - int64_t size_k) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - // Verify num_bits - TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, - "num_bits must be uint4b8 or uint8b128. 
Got = ", b_q_type.str()); - int pack_factor = 32 / b_q_type.size_bits(); - - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_24::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_24::tile_size)); - TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_24::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_24::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_24::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify meta - TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, - "b_meta.size(0) = ", b_meta.size(0), - " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); - TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), - " is not size_n * 2 = ", size_n * 2); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - TORCH_CHECK(a.dtype() == torch::kFloat16, - "A is not float16, currently only float16 is supported"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify b_meta device and strides - TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); - TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - TORCH_CHECK(b_scales.dtype() == torch::kFloat16, - "A is not float16, currently only float16 is supported"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - int thread_k = -1; - int thread_m = -1; - int sms = -1; - int max_par = marlin_24::max_par; - - int groupsize = -1; - if (b_scales.size(0) > 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - groupsize = size_k / b_scales.size(0); - groupsize /= 2; // Because of 24 - } - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 64, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_24::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_24::min_thread_n)); - int min_workspace_size = - (size_n / marlin_24::min_thread_n) * marlin_24::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below 
min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_24::marlin_cuda_2_4( - a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), - b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_m, sms, max_par); - - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); -} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index cdaf873a1..97c0e80e7 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -259,14 +259,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // custom types: // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA - // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. - ops.def( - "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " - "Tensor b_scales, Tensor workspace, " - "int b_q_type, " - "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor"); - // conditionally compiled so impl in source file - // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. ops.def( "machete_supported_schedules(" diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py index 209a879bf..ed4c92d90 100644 --- a/tests/compile/fullgraph/test_full_graph.py +++ b/tests/compile/fullgraph/test_full_graph.py @@ -58,17 +58,6 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): ) ) - if is_quant_method_supported("gptq_marlin_24"): - TEST_MODELS.append( - ( - "alexm-nm/tinyllama-24-marlin24-4bit-g128", - { - "quantization": "gptq_marlin_24", - "allow_deprecated_quantization": True, - }, - ) - ) - if not current_platform.is_rocm() and is_quant_method_supported("awq"): TEST_MODELS.append( ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"}) diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 6b3d14da2..3453753ec 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -10,15 +10,9 @@ import itertools import pytest import torch -from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck +from tests.kernels.utils import opcheck from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, - GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, - GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, -) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_quant_int8, ) @@ -36,15 +30,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize, marlin_weights, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( - marlin_24_quantize, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( awq_pack, gptq_pack, @@ -57,9 +47,7 @@ from vllm.scalar_type import scalar_types if current_platform.is_rocm(): pytest.skip( - "These tests require gptq_marlin_repack," - "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm," - "or marlin_gemm which are not supported on ROCm.", + "These tests require 
marlin, which is not supported on ROCm.", allow_module_level=True, ) @@ -71,9 +59,6 @@ USE_FP32_REDUCE_OPTS = [True] MARLIN_K_CHUNKS = [128] MARLIN_N_CHUNKS = [64, 256] -MARLIN_24_K_CHUNKS = [128] -MARLIN_24_N_CHUNKS = [512] - MARLIN_REPACK_NK_FACTORS = [ (4, 8), (7, 5), @@ -538,96 +523,6 @@ def test_marlin_gemm( assert max_diff < 0.04 -# TODO: find better way to test this? -@torch.compile(fullgraph=True) -def marlin_24_gemm_tester( - a_input, - marlin_24_q_w_comp, - marlin_24_meta, - marlin_24_s, - scratch, - quant_type, - size_m, - size_n, - size_k, -): - return ops.gptq_marlin_24_gemm( - a_input, - marlin_24_q_w_comp, - marlin_24_meta, - marlin_24_s, - scratch, - quant_type, - size_m, - size_n, - size_k, - ) - - -@pytest.mark.skipif( - not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.", -) -@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) -@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) -@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors): - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize( - b_weight, quant_type, group_size - ) - - workspace_24 = MarlinWorkspace( - size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL - ) - - output_ref = torch.matmul(a_input, w_24_ref) - - opcheck( - torch.ops._C.gptq_marlin_24_gemm, - ( - a_input, - marlin_24_q_w_comp, - marlin_24_meta, - marlin_24_s, - workspace_24.scratch, - quant_type.id, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - ), - test_utils=DEFAULT_OPCHECK_TEST_UTILS, - ) - - output = marlin_24_gemm_tester( - a_input, - marlin_24_q_w_comp, - marlin_24_meta, - marlin_24_s, - workspace_24.scratch, - quant_type, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - ) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - def test_marlin_gemm_subset_input(): quant_type = scalar_types.uint4b8 group_size = 128 diff --git a/tests/models/quantization/test_gptq_marlin_24.py b/tests/models/quantization/test_gptq_marlin_24.py deleted file mode 100644 index 43d1d35fa..000000000 --- a/tests/models/quantization/test_gptq_marlin_24.py +++ /dev/null @@ -1,87 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the outputs of a GPTQ model to a Marlin_24 model. - -Note: GPTQ and Marlin_24 do not have bitwise correctness. -As a result, in this test, we just confirm that the top selected tokens of the -Marlin/GPTQ models are in the top 3 selections of each other. 
-""" - -from dataclasses import dataclass - -import pytest - -from tests.quantization.utils import is_quant_method_supported -from vllm.platforms import current_platform - -from ..utils import check_logprobs_close - - -@dataclass -class ModelPair: - model_marlin: str - model_gptq: str - - -model_pairs = [ - # 4-bit, group_size == 128 - ModelPair( - model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128", - model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128", - ), - # # 4-bit, group_size == channelwise - # ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise", - # model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"), - # 8-bit, group_size == 128 - ModelPair( - model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128", - model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128", - ), - # # 8-bit, group_size == channelwise - # ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise", - # model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"), -] - - -@pytest.mark.flaky(reruns=2) -@pytest.mark.skipif( - not is_quant_method_supported("gptq_marlin_24") - or current_platform.is_rocm() - or not current_platform.is_cuda(), - reason="Marlin24 is not supported on this GPU type.", -) -@pytest.mark.parametrize("model_pair", model_pairs) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [8]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - vllm_runner, - example_prompts, - model_pair: ModelPair, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: - with vllm_runner( - model_pair.model_marlin, - dtype=dtype, - quantization="gptq_marlin_24", - allow_deprecated_quantization=True, - ) as marlin_24_model: - marlin_24_outputs = marlin_24_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs - ) - - with vllm_runner( - model_pair.model_gptq, dtype=dtype, quantization="gptq" - ) as gptq_model: - gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs - ) - - check_logprobs_close( - outputs_0_lst=gptq_outputs, - outputs_1_lst=marlin_24_outputs, - name_0="gptq", - name_1="marlin_24", - ) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index f447777d5..d9d7f5e2f 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -17,7 +17,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, @@ -307,28 +306,6 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): assert output -@pytest.mark.skipif( - not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
-) -def test_compressed_tensors_w4a16_marlin24(vllm_runner): - model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t" - with vllm_runner(model_path, enforce_eager=True) as llm: - - def check_model(model): - layer = model.model.layers[0] - - qkv_proj = layer.self_attn.qkv_proj - - assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24) - assert qkv_proj.weight_packed.dtype is torch.int32 - - llm.apply_model(check_model) - - output = llm.generate_greedy("Hello my name is", max_tokens=4) - assert output - - def test_compressed_tensors_fp8(vllm_runner): model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test" with vllm_runner(model_path, enforce_eager=True) as llm: diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 01067ca32..b9b7b872f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -499,6 +499,23 @@ def awq_dequantize( return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) +if hasattr(torch.ops._C, "awq_dequantize"): + + @register_fake("_C::awq_dequantize") + def _awq_dequantize_fake( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + split_k_iters: torch.SymInt, + thx: int, + thy: int, + ) -> torch.Tensor: + in_c = qweight.size(0) + qout_c = qweight.size(1) + out_c = qout_c * 8 + return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device) + + def awq_gemm( input: torch.Tensor, qweight: torch.Tensor, @@ -513,6 +530,24 @@ def awq_gemm( return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters) +if hasattr(torch.ops._C, "awq_gemm"): + + @register_fake("_C::awq_gemm") + def _awq_gemm_fake( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: torch.SymInt, + ) -> torch.Tensor: + num_in_feats = input.size(0) + return torch.empty( + (split_k_iters, num_in_feats, qweight.size(1) * 8), + dtype=input.dtype, + device=input.device, + ).sum(0) + + # gptq def gptq_gemm( a: torch.Tensor, @@ -558,152 +593,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return torch.ops._C.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): - - @register_fake("_C::gptq_marlin_24_gemm") - def _gptq_marlin_24_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake("_C::marlin_gemm") - def _marlin_gemm_fake( - a: torch.Tensor, - c: torch.Tensor | None, - b_q_weight: torch.Tensor, - b_bias: torch.Tensor | None, - b_scales: torch.Tensor, - a_scales: torch.Tensor | None, - global_scale: torch.Tensor | None, - b_zeros: torch.Tensor | None, - g_idx: torch.Tensor | None, - perm: torch.Tensor | None, - workspace: torch.Tensor, - b_q_type_id: int, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool = True, 
- use_atomic_add: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, - ) -> torch.Tensor: - dtype = a.dtype - if dtype not in [torch.half, torch.bfloat16]: - dtype = b_scales.dtype - return torch.empty((size_m, size_n), device=a.device, dtype=dtype) - - @register_fake("_C::awq_dequantize") - def _awq_dequantize_fake( - qweight: torch.Tensor, - scales: torch.Tensor, - zeros: torch.Tensor, - split_k_iters: torch.SymInt, - thx: int, - thy: int, - ) -> torch.Tensor: - in_c = qweight.size(0) - qout_c = qweight.size(1) - out_c = qout_c * 8 - return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device) - - @register_fake("_C::awq_gemm") - def _awq_gemm_fake( - input: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - split_k_iters: torch.SymInt, - ) -> torch.Tensor: - num_in_feats = input.size(0) - return torch.empty( - (split_k_iters, num_in_feats, qweight.size(1) * 8), - dtype=input.dtype, - device=input.device, - ).sum(0) - - @register_fake("_C::machete_mm") - def machete_mm_fake( - a: torch.Tensor, - # b_q Should be the tensor returned by machete_prepack_B - b_q: torch.Tensor, - b_type: ScalarType, - out_type: torch.dtype | None = None, - b_group_scales: torch.Tensor | None = None, - b_group_zeros: torch.Tensor | None = None, - b_group_size: int | None = None, - b_channel_scales: torch.Tensor | None = None, - a_token_scales: torch.Tensor | None = None, - schedule: str | None = None, - ) -> torch.Tensor: - m = a.size(0) - n = b_q.size(1) - return torch.empty((m, n), device=a.device, dtype=a.dtype) - - @register_fake("_C::machete_prepack_B") - def machete_prepack_B_fake( - b_q_weight: torch.Tensor, - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: torch.dtype | None, - ) -> torch.Tensor: - return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - - @register_fake("_C::cutlass_w4a8_mm") - def cutlass_w4a8_mm_fake( - a: torch.Tensor, - # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b - b_q: torch.Tensor, - b_group_scales: torch.Tensor, - b_group_size: int, - b_channel_scales: torch.Tensor, - a_token_scales: torch.Tensor, - out_type: torch.dtype | None = None, - maybe_schedule: str | None = None, - ) -> torch.Tensor: - m = a.size(0) - n = b_q.size(1) - out_dtype = out_type if out_type is not None else torch.bfloat16 - return torch.empty((m, n), device=a.device, dtype=out_dtype) - - @register_fake("_C::cutlass_pack_scale_fp8") - def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: - return torch.empty_like(scales, memory_format=torch.contiguous_format) - - @register_fake("_C::cutlass_encode_and_reorder_int4b") - def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: - return torch.empty_like(b, memory_format=torch.contiguous_format) - - @register_fake("_C::cutlass_encode_and_reorder_int4b_grouped") - def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor: - return torch.empty_like(b, memory_format=torch.contiguous_format) - - if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @register_fake("_C::allspark_w8a16_gemm") @@ -1356,6 +1245,36 @@ def marlin_gemm( ) +if hasattr(torch.ops._C, "marlin_gemm"): + + @register_fake("_C::marlin_gemm") + def _marlin_gemm_fake( + a: torch.Tensor, + c: torch.Tensor | None, + b_q_weight: torch.Tensor, + b_bias: torch.Tensor | None, + b_scales: torch.Tensor, + a_scales: torch.Tensor | None, + global_scale: torch.Tensor | None, + b_zeros: torch.Tensor | None, + g_idx: 
torch.Tensor | None, + perm: torch.Tensor | None, + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, + ) -> torch.Tensor: + dtype = a.dtype + if dtype not in [torch.half, torch.bfloat16]: + dtype = b_scales.dtype + return torch.empty((size_m, size_n), device=a.device, dtype=dtype) + + # machete def machete_supported_schedules( a_type: torch.dtype, @@ -1404,6 +1323,27 @@ def machete_mm( ) +if hasattr(torch.ops._C, "machete_mm"): + + @register_fake("_C::machete_mm") + def machete_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: torch.dtype | None = None, + b_group_scales: torch.Tensor | None = None, + b_group_zeros: torch.Tensor | None = None, + b_group_size: int | None = None, + b_channel_scales: torch.Tensor | None = None, + a_token_scales: torch.Tensor | None = None, + schedule: str | None = None, + ) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + def machete_prepack_B( b_q_weight: torch.Tensor, a_type: torch.dtype, @@ -1415,6 +1355,18 @@ def machete_prepack_B( ) +if hasattr(torch.ops._C, "machete_prepack_B"): + + @register_fake("_C::machete_prepack_B") + def machete_prepack_B_fake( + b_q_weight: torch.Tensor, + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: torch.dtype | None, + ) -> torch.Tensor: + return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) + + # CUTLASS W4A8 def cutlass_w4a8_mm( a: torch.Tensor, @@ -1439,14 +1391,48 @@ def cutlass_w4a8_mm( ) +if hasattr(torch.ops._C, "cutlass_w4a8_mm"): + + @register_fake("_C::cutlass_w4a8_mm") + def cutlass_w4a8_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: torch.dtype | None = None, + maybe_schedule: str | None = None, + ) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + out_dtype = out_type if out_type is not None else torch.bfloat16 + return torch.empty((m, n), device=a.device, dtype=out_dtype) + + def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: return torch.ops._C.cutlass_pack_scale_fp8(scales) +if hasattr(torch.ops._C, "cutlass_pack_scale_fp8"): + + @register_fake("_C::cutlass_pack_scale_fp8") + def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: + return torch.empty_like(scales, memory_format=torch.contiguous_format) + + def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: return torch.ops._C.cutlass_encode_and_reorder_int4b(b) +if hasattr(torch.ops._C, "cutlass_encode_and_reorder_int4b"): + + @register_fake("_C::cutlass_encode_and_reorder_int4b") + def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) + + def cutlass_w4a8_moe_mm( out_tensors: torch.Tensor, a_tensors: torch.Tensor, @@ -1519,6 +1505,17 @@ def cutlass_encode_and_reorder_int4b_grouped( return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors) +if hasattr(torch.ops._C, "cutlass_encode_and_reorder_int4b_grouped"): + + @register_fake("_C::cutlass_encode_and_reorder_int4b_grouped") + def 
cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return torch.ops._C.permute_cols(a, perm) + + if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") @@ -1526,10 +1523,6 @@ if hasattr(torch.ops._C, "permute_cols"): return torch.empty_like(a) -def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: - return torch.ops._C.permute_cols(a, perm) - - # fp4 def scaled_fp4_quant( input: torch.Tensor, diff --git a/vllm/config/model.py b/vllm/config/model.py index 421e1b790..527af4c54 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -890,7 +890,6 @@ class ModelConfig: # `override_quantization_method` method) must be checked in order # of preference (this is particularly important for GPTQ). overrides = [ - "gptq_marlin_24", "gptq_marlin", "awq_marlin", "ipex", diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 2582d100a..cc0fdfa8e 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -18,7 +18,6 @@ QuantizationMethods = Literal[ "modelopt", "modelopt_fp4", "gguf", - "gptq_marlin_24", "gptq_marlin", "awq_marlin", "gptq", @@ -41,7 +40,6 @@ DEPRECATED_QUANTIZATION_METHODS = [ "ptpc_fp8", "fbgemm_fp8", "fp_quant", - "gptq_marlin_24", "experts_int8", "ipex", "petit_nvfp4", @@ -122,7 +120,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .gguf import GGUFConfig from .gptq import GPTQConfig from .gptq_marlin import GPTQMarlinConfig - from .gptq_marlin_24 import GPTQMarlin24Config from .inc import INCConfig from .ipex_quant import IPEXConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config @@ -140,7 +137,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, "gguf": GGUFConfig, - "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 5745cb547..2e61b0609 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -40,7 +40,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso CompressedTensorsMoEMethod, ) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, CompressedTensorsScheme, @@ -49,7 +48,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Mxfp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, @@ -610,29 +608,19 @@ class CompressedTensorsConfig(QuantizationConfig): actorder=weight_quant.actorder, ) - if self._is_wNa16_group_channel(weight_quant, input_quant): - if ( - format == CompressionFormat.marlin_24.value - and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS - ): - assert weight_quant.symmetric - return 
CompressedTensorsW4A16Sparse24( - strategy=weight_quant.strategy, - num_bits=weight_quant.num_bits, - group_size=weight_quant.group_size, - ) - if ( - format == CompressionFormat.pack_quantized.value - and weight_quant.num_bits in WNA16_SUPPORTED_BITS - ): - return CompressedTensorsWNA16( - num_bits=weight_quant.num_bits, - strategy=weight_quant.strategy, - symmetric=weight_quant.symmetric, - group_size=weight_quant.group_size, - actorder=weight_quant.actorder, - layer_name=layer_name, - ) + if ( + self._is_wNa16_group_channel(weight_quant, input_quant) + and (format == CompressionFormat.pack_quantized.value) + and (weight_quant.num_bits in WNA16_SUPPORTED_BITS) + ): + return CompressedTensorsWNA16( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder, + layer_name=layer_name, + ) act_quant_format = is_activation_quantization_format(format) if act_quant_format: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 6d40685f0..c9dd98dfd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -5,10 +5,6 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int -from .compressed_tensors_w4a16_24 import ( - W4A16SPARSE24_SUPPORTED_BITS, - CompressedTensorsW4A16Sparse24, -) from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 @@ -23,11 +19,9 @@ __all__ = [ "CompressedTensorsScheme", "CompressedTensorsWNA16", "CompressedTensorsW8A16Fp8", - "CompressedTensorsW4A16Sparse24", "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", - "W4A16SPARSE24_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", "CompressedTensorsW4A16Mxfp4", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py deleted file mode 100644 index dd0f4b3d8..000000000 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ /dev/null @@ -1,176 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Callable - -import torch -from torch.nn import Parameter - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, -) -from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, - GPTQ_MARLIN_24_MIN_THREAD_N, -) -from vllm.model_executor.parameter import ( - BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter, -) -from vllm.scalar_type import scalar_types - -__all__ = ["CompressedTensorsW4A16Sparse24"] -W4A16SPARSE24_SUPPORTED_TYPES_MAP = { - 4: scalar_types.uint4b8, -} -W4A16SPARSE24_SUPPORTED_BITS = 
list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys()) - - -class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): - def __init__(self, strategy: str, num_bits: int, group_size: int | None = None): - self.strategy = strategy - self.group_size = group_size - self.tile_size = 16 - - if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: - raise ValueError( - f"Unsupported num_bits = {num_bits}. " - f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}" - ) - - self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] - - if self.strategy == "group" and self.group_size is None: - raise ValueError("group_size must be given when using strategy group") - - @classmethod - def get_min_capability(cls) -> int: - # ampere + up - return 80 - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile to be torch.nn.Parameter - layer.weight_packed = Parameter(layer.weight_packed.data, requires_grad=False) - layer.scale_packed = Parameter(layer.scale_packed.data, requires_grad=False) - layer.meta = Parameter(layer.meta.data, requires_grad=False) - - def create_weights( - self, - layer: torch.nn.Module, - input_size: int, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, - weight_loader: Callable, - **kwargs, - ): - assert params_dtype == torch.float16, ( - "float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501 - ) - - pack_factor = 32 // self.quant_type.size_bits - output_size_per_partition = sum(output_partition_sizes) - - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.tile_size // 2, - output_size_per_partition * self.tile_size // pack_factor, - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=pack_factor, - marlin_tile_size=self.tile_size, - weight_loader=weight_loader, - ) - - input_groups = ( - 1 - if self.group_size is None - else input_size_per_partition // self.group_size - ) - - weight_scale_args = { - "data": torch.empty( - input_groups, - output_size_per_partition, - dtype=params_dtype, - ), - "weight_loader": weight_loader, - } - - if self.group_size is not None: - scales = GroupQuantScaleParameter( - output_dim=1, input_dim=0, **weight_scale_args - ) - else: - scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) - - weight_shape = BasevLLMParameter( - data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader - ) - - meta = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // 8 // 2 // 2, - output_size_per_partition * 2, - dtype=torch.int16, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=1, - marlin_tile_size=2, - weight_loader=weight_loader, - ) - - layer.register_parameter("weight_packed", qweight) - layer.register_parameter("weight_shape", weight_shape) - layer.register_parameter("scale_packed", scales) - layer.register_parameter("meta", meta) - - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_24_MIN_THREAD_N - ) * GPTQ_MARLIN_24_MAX_PARALLEL - - workspace = Parameter( - torch.zeros(max_workspace_size, dtype=torch.int), requires_grad=False - ) - layer.workspace = workspace - - def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None - ) -> torch.Tensor: - qweight = layer.weight_packed - meta = layer.meta - scales = layer.scale_packed - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = 
scales.shape[1] - - output_2d = ops.gptq_marlin_24_gemm( - x_2d, - qweight, - meta, - scales, - workspace, - self.quant_type, - size_m, - size_n, - size_k, - ) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py deleted file mode 100644 index 2fb614b47..000000000 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ /dev/null @@ -1,320 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import ( - QuantizationConfig, - QuantizationMethods, -) -from vllm.model_executor.parameter import ( - BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter, -) -from vllm.scalar_type import scalar_types - -logger = init_logger(__name__) - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - - -class GPTQMarlin24Config(QuantizationConfig): - """Config class for Marlin24.""" - - def __init__( - self, - weight_bits: int, - group_size: int, - ) -> None: - super().__init__() - quant_type = { - 4: scalar_types.uint4b8, - 8: scalar_types.uint8b128, - }.get(weight_bits) - - self.group_size = group_size - - # Verify - if quant_type is None or quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: - raise ValueError( - f"Marlin_24 does not support quant_type = {quant_type}. " - f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " - "are supported." - ) - if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Marlin_24 does not support group_size = {self.group_size}. " - f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " - "are supported." - ) - - self.quant_type = quant_type - - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // self.quant_type.size_bits - - # Tile size used by marlin kernels. - self.tile_size = 16 - - # Min out_features dim - self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N - - # Min in_features dim - self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL - - # Permutation length used by the marlin kernels. 
- self.perm_len = 1024 - - def __repr__(self) -> str: - return "Marlin24Config(quant_type={}, group_size={})".format( - self.quant_type, self.group_size - ) - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "gptq_marlin_24" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - # Need to figure it out - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return ["quantize_config.json"] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config": - weight_bits = cls.get_from_keys(config, ["bits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - return cls(weight_bits, group_size) - - @classmethod - def override_quantization_method( - cls, hf_quant_cfg, user_quant - ) -> QuantizationMethods | None: - is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24" - - is_valid_user_quant = ( - user_quant is None or user_quant == "gptq" or user_quant == "gptq_marlin_24" - ) - - if is_marlin_24_format and is_valid_user_quant: - msg = "The model is serialized in {} format. Using {} kernel.".format( - cls.get_name(), cls.get_name() - ) - logger.info(msg) - return cls.get_name() - - return None - - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional["GPTQMarlin24LinearMethod"]: - if isinstance(layer, LinearBase): - return GPTQMarlin24LinearMethod(self) - return None - - -class GPTQMarlin24LinearMethod(LinearMethodBase): - """Linear method for Marlin24. - - Args: - quant_config: The Marlin24 quantization config. - """ - - def __init__(self, quant_config: GPTQMarlin24Config): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - del output_size # Unused. - weight_loader = extra_weight_attrs["weight_loader"] - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}" - ) - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}." - ) - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}." - ) - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}." - ) - if ( - self.quant_config.group_size != -1 - and input_size_per_partition % self.quant_config.group_size != 0 - ): - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}." 
- ) - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2 - ) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError("Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size // 2, - output_size_per_partition - * self.quant_config.tile_size - // self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader, - ) - - # Meta - meta = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // 8 // 2 // 2, - output_size_per_partition * 2, - device="cuda", - dtype=torch.int16, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=1, - marlin_tile_size=2, - weight_loader=weight_loader, - ) - - # Determine if channelwise or not - input_groups = ( - 1 - if self.quant_config.group_size == -1 - else input_size_per_partition // self.quant_config.group_size - ) - - weight_scale_args = { - "data": torch.empty( - input_groups, - output_size_per_partition, - device="cuda", - dtype=params_dtype, - ), - "weight_loader": weight_loader, - } - if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) - else: - scales = GroupQuantScaleParameter( - output_dim=1, input_dim=0, **weight_scale_args - ) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // self.quant_config.min_n_threads - ) * self.quant_config.max_parallel - - workspace = BasevLLMParameter( - data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int), - weight_loader=weight_loader, - ) - - layer.register_parameter("B_24", qweight) - layer.register_parameter("B_meta", meta) - layer.register_parameter("s", scales) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile - layer.B_24 = Parameter(layer.B_24.data, requires_grad=False) - layer.s = Parameter(layer.s.data, requires_grad=False) - layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False) - layer.workspace = Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - qweight = layer.B_24 - meta = layer.B_meta - scales = layer.s - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = scales.shape[1] - - output_2d = ops.gptq_marlin_24_gemm( - x_2d, - qweight, - meta, - scales, - workspace, - self.quant_config.quant_type, - size_m, - size_n, - size_k, - ) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 90011f116..000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,467 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
-"""Utility functions used for tests and benchmarks""" - -import random - -import numpy -import torch - -from vllm.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. - interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. 
-def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
Note that torch.float - # datatype is handled pretty much the same as - # torch.half/torch.bfloat16, as metadata for a pair of torch.float - # value is encoded as if underlying 8 bytes contain four - # torch.half/torch.bfloat16 values, where either first two or last - # two are zeros. - meta_2 = torch.empty( - (m, meta_ncols, 2 * quadbits_per_meta_elem), - dtype=meta_dtype, - device=device, - ) - if quadbits_per_meta_elem == 4: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - elif quadbits_per_meta_elem == 8: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - meta_2[:, :, 8] = (meta >> 16) & 0b11 - meta_2[:, :, 9] = (meta >> 18) & 0b11 - meta_2[:, :, 10] = (meta >> 20) & 0b11 - meta_2[:, :, 11] = (meta >> 22) & 0b11 - meta_2[:, :, 12] = (meta >> 24) & 0b11 - meta_2[:, :, 13] = (meta >> 26) & 0b11 - meta_2[:, :, 14] = (meta >> 28) & 0b11 - meta_2[:, :, 15] = (meta >> 30) & 0b11 - - dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4 - ).view(-1, 1).repeat(1, 2).view(-1) - - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) - if sparse.dtype != torch.float: - # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) - else: - dense.view(torch.half).scatter_( - 0, dense_offsets, sparse.view(torch.half).view(-1) - ) - - return dense.view(m, 2 * k) - - -def mask_creator(tensor): - """ - Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask - will correspond to the given tensor. 
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into {M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: list[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: list[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: list[int] = [] - for i in range(32): - perm1: list[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list
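For reviewers who want to double-check the removed reference path against an older checkout, a minimal round-trip sketch follows. It assumes the deleted helpers (mask_creator, sparse_semi_structured_from_dense_cutlass, sparse_semi_structured_to_dense_cutlass) are still importable from a pre-removal tree; the (64, 64) shape is illustrative and only needs to satisfy the divisibility checks above (rows % 32 == 0 and cols % 16 == 0 for fp16 inputs):

import torch

# Prune a random fp16 matrix to 2:4 along its last dimension (mask computed in
# fp32 for portability), then verify that the CUTLASS-style compress/decompress
# pair reconstructs it exactly.
dense = torch.randn(64, 64, dtype=torch.half)
dense_24 = dense * mask_creator(dense.float()).bool()

sparse, meta = sparse_semi_structured_from_dense_cutlass(dense_24)
# sparse keeps the two surviving values of every group of four: shape (64, 32).
# meta packs four 4-bit position codes per int16 element: shape (64, 4).
rebuilt = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(rebuilt, dense_24)

marlin_24_quantize above chained the same pieces end to end (inject_24, gptq_quantize_weights, compress_quantized_24_weight, then the Marlin-specific weight and scale permutations), producing the reference weights, compressed weights, metadata, and scales that the 2:4 path consumed before this removal.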