diff --git a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml deleted file mode 100644 index 9a9c74974..000000000 --- a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# For vllm script, with -t option (tensor parallel size). -# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2 -model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.6353 - - name: "exact_match,flexible-extract" - value: 0.637 -limit: null -num_fewshot: null diff --git a/CMakeLists.txt b/CMakeLists.txt index ddc9bcadb..e438ff41d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -343,7 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" - "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" "csrc/quantization/w8a8/fp8/per_token_group_quant.cu" "csrc/quantization/w8a8/int8/per_token_group_quant.cu") @@ -619,31 +618,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - # - # 2:4 Sparse Kernels - - # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor - # require CUDA 12.2 or later (and only work on Hopper). - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) - set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") - message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) - message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " - "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " - "if you intend on running FP8 sparse quantized models on Hopper.") - else() - message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require # CUDA 12.8 or later if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py deleted file mode 100644 index 7720f15e4..000000000 --- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +++ /dev/null @@ -1,517 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import copy -import itertools -import pickle as pkl -import time -from collections.abc import Callable, Iterable - -import torch -import torch.utils.benchmark as TBenchmark -from torch.utils.benchmark import Measurement as TMeasurement -from utils import make_rand_sparse_tensors -from weight_shapes import WEIGHT_SHAPES - -from vllm import _custom_ops as ops -from vllm.utils.argparse_utils import FlexibleArgumentParser - -DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] -DEFAULT_TP_SIZES = [1] - - -# bench -def bench_fn( - label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs -) -> TMeasurement: - min_run_time = 1 - - globals = { - "args": args, - "kwargs": kwargs, - "fn": fn, - } - return TBenchmark.Timer( - stmt="fn(*args, **kwargs)", - globals=globals, - label=label, - sub_label=sub_label, - description=description, - ).blocked_autorange(min_run_time=min_run_time) - - -def bench_int8( - dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str -) -> Iterable[TMeasurement]: - assert dtype == torch.int8 - b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) - - out = ops.cutlass_scaled_sparse_mm( - a, b_compressed, e, scale_a, scale_b, torch.bfloat16 - ) - out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) - - if not torch.allclose(out, out_ref): - print("Incorrect results") - print(out) - print(out_ref) - else: - print("Correct results") - - timers = [] - # pytorch impl - bfloat16 - timers.append( - bench_fn( - label, - sub_label, - "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, - a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16), - ) - ) - - # pytorch impl - float16 - timers.append( - bench_fn( - label, - sub_label, - "pytorch_fp16_fp16_fp16_matmul-no-scales", - torch.mm, - a.to(dtype=torch.float16), - b.to(dtype=torch.float16), - ) - ) - - # cutlass impl - timers.append( - bench_fn( - label, - sub_label, - "cutlass_i8_i8_bf16_scaled_mm", - ops.cutlass_scaled_mm, - a, - b, - scale_a, - scale_b, - torch.bfloat16, - ) - ) - - # cutlass with bias - timers.append( - bench_fn( - label, - sub_label, - "cutlass_i8_i8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, - a, - b, - scale_a, - scale_b, - torch.bfloat16, - bias, - ) - ) - - # cutlass sparse impl - timers.append( - bench_fn( - label, - sub_label, - "cutlass_i8_i8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, - a, - b_compressed, - e, - scale_a, - scale_b, - torch.bfloat16, - ) - ) - - # cutlass sparse with bias - timers.append( - bench_fn( - label, - sub_label, - "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, - a, - b_compressed, - e, - scale_a, - scale_b, - torch.bfloat16, - bias, - ) - ) - - return timers - - -def bench_fp8( - dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str -) -> Iterable[TMeasurement]: - assert dtype == torch.float8_e4m3fn - b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) - - out = ops.cutlass_scaled_sparse_mm( - a, b_compressed, e, scale_a, scale_b, torch.bfloat16 - ) - out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) - - if not torch.allclose(out, out_ref): - print("Incorrect results") - print(out) - print(out_ref) - else: - print("Correct results") - - timers = [] - - # pytorch impl w. bf16 - timers.append( - bench_fn( - label, - sub_label, - "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, - a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), - ) - ) - - # pytorch impl: bf16 output, without fp8 fast accum - timers.append( - bench_fn( - label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - ) - ) - - # pytorch impl: bf16 output, with fp8 fast accum - timers.append( - bench_fn( - label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True, - ) - ) - - # pytorch impl: fp16 output, without fp8 fast accum - timers.append( - bench_fn( - label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16, - ) - ) - - # pytorch impl: fp16 output, with fp8 fast accum - timers.append( - bench_fn( - label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16, - use_fast_accum=True, - ) - ) - - # cutlass impl: bf16 output - timers.append( - bench_fn( - label, - sub_label, - "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, - a, - b, - scale_a, - scale_b, - torch.bfloat16, - ) - ) - - # cutlass impl: bf16 output - timers.append( - bench_fn( - label, - sub_label, - "cutlass_fp8_fp8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, - a, - b_compressed, - e, - scale_a, - scale_b, - torch.bfloat16, - ) - ) - - # cutlass impl: fp16 output - timers.append( - bench_fn( - label, - sub_label, - "cutlass_fp8_fp8_fp16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, - a, - b_compressed, - e, - scale_a, - scale_b, - torch.float16, - ) - ) - - # cutlass impl: bf16 output, with bias - timers.append( - bench_fn( - label, - sub_label, - "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, - a, - b_compressed, - e, - scale_a, - scale_b, - torch.bfloat16, - bias, - ) - ) - - # cutlass impl: fp16 output, with bias - timers.append( - bench_fn( - label, - sub_label, - "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, - a, - b_compressed, - e, - scale_a, - scale_b, - torch.float16, - bias.to(dtype=torch.float16), - ) - ) - - return timers - - -def bench( - dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str -) -> Iterable[TMeasurement]: - if dtype == torch.int8: - return bench_int8(dtype, m, k, n, label, sub_label) - if dtype == torch.float8_e4m3fn: - return bench_fp8(dtype, m, k, n, label, sub_label) - raise ValueError( - f"Unsupported dtype {dtype}: should be one of torch.int8, torch.float8_e4m3fn." - ) - - -# runner -def print_timers(timers: Iterable[TMeasurement]): - compare = TBenchmark.Compare(timers) - compare.print() - - -def run( - dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]] -) -> Iterable[TMeasurement]: - results = [] - for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})") - print_timers(timers) - results.extend(timers) - - return results - - -# output makers -def make_output( - data: Iterable[TMeasurement], - MKNs: Iterable[tuple[int, int, int]], - base_description: str, - timestamp=None, -): - print(f"== All Results {base_description} ====") - print_timers(data) - - # pickle all the results - timestamp = int(time.time()) if timestamp is None else timestamp - with open(f"{base_description}-{timestamp}.pkl", "wb") as f: - pkl.dump(data, f) - - -# argparse runners - - -def run_square_bench(args): - dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) - MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, MKNs) - - make_output(data, MKNs, f"square_bench-{args.dtype}") - - -def run_range_bench(args): - dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) - n = len(dim_sizes) - Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes - Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes - Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes - MKNs = list(zip(Ms, Ks, Ns)) - data = run(args.dtype, MKNs) - - make_output(data, MKNs, f"range_bench-{args.dtype}") - - -def run_model_bench(args): - print("Benchmarking models:") - for i, model in enumerate(args.models): - print(f"[{i}] {model}") - - def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: - KNs = [] - for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): - KN[tp_split_dim] = KN[tp_split_dim] // tp_size - KNs.append(KN) - return KNs - - model_bench_data = [] - models_tps = list(itertools.product(args.models, args.tp_sizes)) - for model, tp_size in models_tps: - Ms = args.batch_sizes - KNs = model_shapes(model, tp_size) - MKNs = [] - for m in Ms: - for k, n in KNs: - MKNs.append((m, k, n)) - - data = run(args.dtype, MKNs) - model_bench_data.append(data) - - # Print all results - for data, model_tp in zip(model_bench_data, models_tps): - model, tp_size = model_tp - print(f"== Results {args.dtype} {model}-TP{tp_size} ====") - print_timers(data) - - timestamp = int(time.time()) - - all_data = [] - for d in model_bench_data: - all_data.extend(d) - # pickle all data - with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) - - -if __name__ == "__main__": - - def to_torch_dtype(dt): - if dt == "int8": - return torch.int8 - if dt == "fp8": - return torch.float8_e4m3fn - raise ValueError("unsupported dtype") - - parser = FlexibleArgumentParser( - description=""" -Benchmark Cutlass GEMM. - - To run square GEMMs: - python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 - - To run constant N and K and sweep M: - python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 - - To run dimensions from a model: - python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 - - Output: - - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter, - ) - - parser.add_argument( - "--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']", - ) - subparsers = parser.add_subparsers(dest="cmd") - - square_parser = subparsers.add_parser("square_bench") - square_parser.add_argument("--dim-start", type=int, required=True) - square_parser.add_argument("--dim-end", type=int, required=True) - square_parser.add_argument("--dim-increment", type=int, required=True) - square_parser.set_defaults(func=run_square_bench) - - range_parser = subparsers.add_parser("range_bench") - range_parser.add_argument("--dim-start", type=int, required=True) - range_parser.add_argument("--dim-end", type=int, required=True) - range_parser.add_argument("--dim-increment", type=int, required=True) - range_parser.add_argument("--m-constant", type=int, default=None) - range_parser.add_argument("--n-constant", type=int, default=None) - range_parser.add_argument("--k-constant", type=int, default=None) - range_parser.set_defaults(func=run_range_bench) - - model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument( - "--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys(), - ) - model_parser.add_argument( - "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES - ) - model_parser.add_argument( - "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES - ) - model_parser.set_defaults(func=run_model_bench) - - args = parser.parse_args() - args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py index 6cbcf6b68..659c68bb1 100644 --- a/benchmarks/cutlass_benchmarks/utils.py +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -5,8 +5,6 @@ import torch -import vllm._custom_ops as ops - def to_fp8(tensor: torch.Tensor) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fn) @@ -39,49 +37,3 @@ def make_rand_tensors( return to_fp8(a), to_fp8(b) raise ValueError("unsupported dtype") - - -def prune_to_2_4(tensor): - # Reshape tensor to [N, 4] where N is number of groups of 4 - original_shape = tensor.shape - reshaped = tensor.reshape(-1, 4) - - # Get indices of top 2 absolute values in each group of 4 - _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) - - # Create binary mask - mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) - - # Apply mask and reshape back - pruned = reshaped * mask - - # Turn all -0.0 to 0.0 - pruned[pruned == -0.0] = 0.0 - - return pruned.reshape(original_shape) - - -def make_rand_sparse_tensors( - dtype: torch.dtype, m: int, n: int, k: int -) -> tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device="cuda") * 5 - b = torch.randn((n, k), device="cuda").t() * 5 - - b = prune_to_2_4(b.t()).t() - - if dtype == torch.int8: - a, b = to_int8(a), to_int8(b) - elif dtype == torch.float8_e4m3fn: - a, b = to_fp8(a), to_fp8(b) - elif dtype == torch.float16: - a, b = to_fp16(a), to_fp16(b) - elif dtype == torch.bfloat16: - a, b = to_bf16(a), to_bf16(b) - else: - raise ValueError("unsupported dtype") - - b_compressed, e = ops.cutlass_sparse_compress(b.t()) - - # Compressed B, Metadata, Original A, B - return b_compressed, e, a, b diff --git a/csrc/ops.h b/csrc/ops.h index 26caf7f7d..ceb8e021c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -285,16 +285,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, std::optional const& azp, std::optional const& bias); -bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability); - -void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& e, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - std::optional const& bias); - -std::vector cutlass_sparse_compress(torch::Tensor const& a); - std::tuple scaled_fp4_quant_func( torch::Tensor const& input, torch::Tensor const& input_scale, bool is_sf_swizzled_layout); diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cuh b/csrc/sparse/cutlass/sparse_compressor_c3x.cuh deleted file mode 100644 index 2cc235f3a..000000000 --- a/csrc/sparse/cutlass/sparse_compressor_c3x.cuh +++ /dev/null @@ -1,90 +0,0 @@ -#pragma once - -// clang-format will break include orders -// clang-format off -#include - -#if defined CUDA_VERSION && CUDA_VERSION >= 12020 -#include "sparse_scaled_mm_c3x.cuh" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/transform/device/transform_universal_adapter.hpp" -#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" - -// clang-format on - -using namespace cute; -using namespace vllm; - -using CompressorResult = std::tuple; -/// Make A structured sparse by replacing elements with 0 and compress it -template -CompressorResult cutlass_sparse_compress(torch::Tensor const& a) { - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn || - a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16); - TORCH_CHECK(a.dim() == 2) - // Check for strides and alignment - TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity - TORCH_CHECK(a.stride(1) == 1) - - using GemmKernel = typename Gemm::KernelType; - using ElementA = typename Gemm::ElementAB; - using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; - - int m = a.size(0); - int k = a.size(1); - using ProblemShape = typename GemmKernel::ProblemShape; - ProblemShape prob_shape{m, 1, k, 1}; - - int64_t lda = a.stride(0); - using StrideA = Stride, int64_t>; - StrideA a_stride{lda, Int<1>{}, 0}; - - using CompressorUtility = typename Gemm::CompressorUtility; - CompressorUtility compressor_utility(prob_shape, a_stride); - - // Allocate buffers for the metadata E and the compressed matrix A - int ME = compressor_utility.get_metadata_m_physical(); - int KE = compressor_utility.get_metadata_k_physical(); - int MC = compressor_utility.get_tensorA_m_physical(); - int KC = compressor_utility.get_tensorA_k_physical(); - - auto const a_meta_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto const a_nzs_options = - torch::TensorOptions().dtype(a.dtype()).device(a.device()); - - auto a_meta = torch::zeros({ME, KE}, a_meta_options); - auto a_nzs = torch::zeros({MC, KC}, a_nzs_options); - - auto a_ptr = static_cast(a.data_ptr()); - auto a_nzs_ptr = static_cast(a_nzs.data_ptr()); - auto a_meta_ptr = static_cast(a_meta.data_ptr()); - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = a.device().index(); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); - - using Compressor = typename Gemm::Compressor; - typename Compressor::Arguments arguments{ - prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}}; - - Compressor compressor_op; - size_t workspace_size = Compressor::get_workspace_size(arguments); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - CUTLASS_CHECK(compressor_op.can_implement(arguments)); - CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr())); - CUTLASS_CHECK(compressor_op.run()); - CUDA_CHECK(cudaDeviceSynchronize()); - - return {a_meta, a_nzs}; -} - -#endif diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu deleted file mode 100644 index d053ecc8d..000000000 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu +++ /dev/null @@ -1,307 +0,0 @@ -// clang-format will break include orders -// clang-format off -#include - -#if defined CUDA_VERSION && CUDA_VERSION >= 12020 -#include "sparse_scaled_mm_c3x.cuh" -// clang-format on - -using namespace cute; -using namespace vllm; - -struct GemmCallerTraits { - using return_type = void; - - template - static return_type invoke(Args&&... args) { - return cutlass_sparse_gemm_caller(std::forward(args)...); - } -}; - -struct GemmCompressorTraits { - using return_type = CompressorResult; - - template - static return_type invoke(Args&&... args) { - return cutlass_sparse_compress(std::forward(args)...); - } -}; - -template typename Epilogue, - typename DispatchFunc, typename... Args> -typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch( - uint32_t m, uint32_t n, Args&&... args) { - static_assert(std::is_same_v); - - using Cutlass3xGemmDefault = - typename sm90_config_default::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_fp8_config_M128::Cutlass3xGemm; - using Cutlass3xGemmM256 = - typename sm90_fp8_config_M256::Cutlass3xGemm; - using Cutlass3xGemmM512 = - typename sm90_fp8_config_M512::Cutlass3xGemm; - - using Cutlass3xGemm1 = - typename sm90_fp8_config_1::Cutlass3xGemm; - using Cutlass3xGemm2 = - typename sm90_fp8_config_2::Cutlass3xGemm; - using Cutlass3xGemm3 = - typename sm90_fp8_config_3::Cutlass3xGemm; - using Cutlass3xGemm4 = - typename sm90_fp8_config_4::Cutlass3xGemm; - using Cutlass3xGemm5 = - typename sm90_fp8_config_5::Cutlass3xGemm; - using Cutlass3xGemm6 = - typename sm90_fp8_config_6::Cutlass3xGemm; - using Cutlass3xGemm7 = - typename sm90_fp8_config_7::Cutlass3xGemm; - using Cutlass3xGemm8 = - typename sm90_fp8_config_8::Cutlass3xGemm; - - uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 - - if (mp2 <= 64) { - if (n == 28672) { - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (n == 4096 || n == 6144) { - return DispatchFunc::template invoke( - std::forward(args)...); - } - } else if (mp2 <= 128) { - if (n == 4096) { - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (n == 28672) { - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (n == 6144) { - return DispatchFunc::template invoke( - std::forward(args)...); - } - } else if (mp2 <= 256) { - if (n == 4096) { - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (n == 28672) { - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (n == 6144) { - return DispatchFunc::template invoke( - std::forward(args)...); - } - } else { - if (n == 6144 || n == 28672) { - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (n == 4096) { - return DispatchFunc::template invoke( - std::forward(args)...); - } - } - - // Otherwise the default heuristic - if (mp2 <= 64) { - // n in [1, 64] - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (mp2 <= 128) { - // n in (64, 128] - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (mp2 <= 256) { - // n in (128, 256] - return DispatchFunc::template invoke( - std::forward(args)...); - } else { - // n in (256, inf) - return DispatchFunc::template invoke( - std::forward(args)...); - } -} - -template typename Epilogue, - typename DispatchFunc, typename... Args> -typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch( - uint32_t m, uint32_t n, Args&&... args) { - using Cutlass3xGemmDefault = - typename sm90_config_default::Cutlass3xGemm; - - return DispatchFunc::template invoke( - std::forward(args)...); -} - -template typename Epilogue, - typename DispatchFunc, typename... Args> -typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch( - uint32_t m, uint32_t n, Args&&... args) { - static_assert(std::is_same_v); - - using Cutlass3xGemmDefault = - typename sm90_config_default::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_int8_config_M128::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_int8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM32NBig = - typename sm90_int8_config_M32_NBig::Cutlass3xGemm; - using Cutlass3xGemmM32NSmall = - typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; - - bool const is_small_n = n < 8192; - uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 - - if (mp2 <= 32) { - // m in [1, 32] - if (is_small_n) { - return DispatchFunc::template invoke( - std::forward(args)...); - } else { - return DispatchFunc::template invoke( - std::forward(args)...); - } - } else if (mp2 <= 64) { - // m in (32, 64] - return DispatchFunc::template invoke( - std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return DispatchFunc::template invoke( - std::forward(args)...); - } else { - // m in (128, inf) - return DispatchFunc::template invoke( - std::forward(args)...); - } -} - -// Dispatch to GEMM implementations based on element types -template