[Sparse24] [Deprecation] Remove Sparse24 CT integration and kernels (#36799)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -1,12 +0,0 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
|
||||
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.6353
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.637
|
||||
limit: null
|
||||
num_fewshot: null
|
||||
@@ -343,7 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
|
||||
"csrc/quantization/w8a8/int8/per_token_group_quant.cu")
|
||||
@@ -619,31 +618,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# 2:4 Sparse Kernels
|
||||
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper).
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
|
||||
"if you intend on running FP8 sparse quantized models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
|
||||
@@ -1,517 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import pickle as pkl
|
||||
import time
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import make_rand_sparse_tensors
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(
|
||||
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
|
||||
) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"args": args,
|
||||
"kwargs": kwargs,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(*args, **kwargs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench_int8(
|
||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||
) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.int8
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
|
||||
)
|
||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
if not torch.allclose(out, out_ref):
|
||||
print("Incorrect results")
|
||||
print(out)
|
||||
print(out_ref)
|
||||
else:
|
||||
print("Correct results")
|
||||
|
||||
timers = []
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm,
|
||||
a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16),
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales",
|
||||
torch.mm,
|
||||
a.to(dtype=torch.float16),
|
||||
b.to(dtype=torch.float16),
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass with bias
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
bias,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass sparse impl
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_i8_i8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass sparse with bias
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
bias,
|
||||
)
|
||||
)
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench_fp8(
|
||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||
) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
|
||||
)
|
||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
if not torch.allclose(out, out_ref):
|
||||
print("Incorrect results")
|
||||
print(out)
|
||||
print(out_ref)
|
||||
else:
|
||||
print("Correct results")
|
||||
|
||||
timers = []
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm,
|
||||
a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"),
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl: bf16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl: fp16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl: fp16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.float16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: bf16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
bias,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: fp16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.float16,
|
||||
bias.to(dtype=torch.float16),
|
||||
)
|
||||
)
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench(
|
||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||
) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return bench_fp8(dtype, m, k, n, label, sub_label)
|
||||
raise ValueError(
|
||||
f"Unsupported dtype {dtype}: should be one of torch.int8, torch.float8_e4m3fn."
|
||||
)
|
||||
|
||||
|
||||
# runner
|
||||
def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
|
||||
|
||||
def run(
|
||||
dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
|
||||
) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# output makers
|
||||
def make_output(
|
||||
data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None,
|
||||
):
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
# pickle all the results
|
||||
timestamp = int(time.time()) if timestamp is None else timestamp
|
||||
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
|
||||
pkl.dump(data, f)
|
||||
|
||||
|
||||
# argparse runners
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||
|
||||
|
||||
def run_range_bench(args):
|
||||
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
|
||||
n = len(dim_sizes)
|
||||
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
|
||||
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||
MKNs = list(zip(Ms, Ks, Ns))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||
|
||||
|
||||
def run_model_bench(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
|
||||
KNs = []
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
KNs.append(KN)
|
||||
return KNs
|
||||
|
||||
model_bench_data = []
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
for model, tp_size in models_tps:
|
||||
Ms = args.batch_sizes
|
||||
KNs = model_shapes(model, tp_size)
|
||||
MKNs = []
|
||||
for m in Ms:
|
||||
for k, n in KNs:
|
||||
MKNs.append((m, k, n))
|
||||
|
||||
data = run(args.dtype, MKNs)
|
||||
model_bench_data.append(data)
|
||||
|
||||
# Print all results
|
||||
for data, model_tp in zip(model_bench_data, models_tps):
|
||||
model, tp_size = model_tp
|
||||
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
|
||||
print_timers(data)
|
||||
|
||||
timestamp = int(time.time())
|
||||
|
||||
all_data = []
|
||||
for d in model_bench_data:
|
||||
all_data.extend(d)
|
||||
# pickle all data
|
||||
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
|
||||
pkl.dump(all_data, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "int8":
|
||||
return torch.int8
|
||||
if dt == "fp8":
|
||||
return torch.float8_e4m3fn
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="""
|
||||
Benchmark Cutlass GEMM.
|
||||
|
||||
To run square GEMMs:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
|
||||
|
||||
To run constant N and K and sweep M:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
|
||||
|
||||
To run dimensions from a model:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
|
||||
|
||||
Output:
|
||||
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
|
||||
square_parser = subparsers.add_parser("square_bench")
|
||||
square_parser.add_argument("--dim-start", type=int, required=True)
|
||||
square_parser.add_argument("--dim-end", type=int, required=True)
|
||||
square_parser.add_argument("--dim-increment", type=int, required=True)
|
||||
square_parser.set_defaults(func=run_square_bench)
|
||||
|
||||
range_parser = subparsers.add_parser("range_bench")
|
||||
range_parser.add_argument("--dim-start", type=int, required=True)
|
||||
range_parser.add_argument("--dim-end", type=int, required=True)
|
||||
range_parser.add_argument("--dim-increment", type=int, required=True)
|
||||
range_parser.add_argument("--m-constant", type=int, default=None)
|
||||
range_parser.add_argument("--n-constant", type=int, default=None)
|
||||
range_parser.add_argument("--k-constant", type=int, default=None)
|
||||
range_parser.set_defaults(func=run_range_bench)
|
||||
|
||||
model_parser = subparsers.add_parser("model_bench")
|
||||
model_parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys(),
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
@@ -5,8 +5,6 @@
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
@@ -39,49 +37,3 @@ def make_rand_tensors(
|
||||
return to_fp8(a), to_fp8(b)
|
||||
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device="cuda") * 5
|
||||
b = torch.randn((n, k), device="cuda").t() * 5
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
|
||||
10
csrc/ops.h
10
csrc/ops.h
@@ -285,16 +285,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& e,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout);
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
template <typename Gemm>
|
||||
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
|
||||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(a.dim() == 2)
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
|
||||
TORCH_CHECK(a.stride(1) == 1)
|
||||
|
||||
using GemmKernel = typename Gemm::KernelType;
|
||||
using ElementA = typename Gemm::ElementAB;
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
using ProblemShape = typename GemmKernel::ProblemShape;
|
||||
ProblemShape prob_shape{m, 1, k, 1};
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
|
||||
using CompressorUtility = typename Gemm::CompressorUtility;
|
||||
CompressorUtility compressor_utility(prob_shape, a_stride);
|
||||
|
||||
// Allocate buffers for the metadata E and the compressed matrix A
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int MC = compressor_utility.get_tensorA_m_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
auto const a_meta_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto const a_nzs_options =
|
||||
torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
|
||||
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
|
||||
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
|
||||
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = a.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
using Compressor = typename Gemm::Compressor;
|
||||
typename Compressor::Arguments arguments{
|
||||
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return {a_meta, a_nzs};
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,307 +0,0 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
struct GemmCallerTraits {
|
||||
using return_type = void;
|
||||
|
||||
template <typename GemmConfig, typename... Args>
|
||||
static return_type invoke(Args&&... args) {
|
||||
return cutlass_sparse_gemm_caller<GemmConfig>(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
struct GemmCompressorTraits {
|
||||
using return_type = CompressorResult;
|
||||
|
||||
template <typename GemmConfig, typename... Args>
|
||||
static return_type invoke(Args&&... args) {
|
||||
return cutlass_sparse_compress<GemmConfig>(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename DispatchFunc, typename... Args>
|
||||
typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch(
|
||||
uint32_t m, uint32_t n, Args&&... args) {
|
||||
static_assert(std::is_same_v<InType, cutlass::float_e4m3_t>);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM256 =
|
||||
typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM512 =
|
||||
typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
using Cutlass3xGemm1 =
|
||||
typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm2 =
|
||||
typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm3 =
|
||||
typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm4 =
|
||||
typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm5 =
|
||||
typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm6 =
|
||||
typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm7 =
|
||||
typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm8 =
|
||||
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 64) {
|
||||
if (n == 28672) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm2>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 4096 || n == 6144) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm1>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 128) {
|
||||
if (n == 4096) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm3>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 28672) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm5>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 6144) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm4>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 256) {
|
||||
if (n == 4096) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm6>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 28672) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm8>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 6144) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm7>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
} else {
|
||||
if (n == 6144 || n == 28672) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm8>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 4096) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm7>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise the default heuristic
|
||||
if (mp2 <= 64) {
|
||||
// n in [1, 64]
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM64>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// n in (64, 128]
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM128>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// n in (128, 256]
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM256>(
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
// n in (256, inf)
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM512>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename DispatchFunc, typename... Args>
|
||||
typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch(
|
||||
uint32_t m, uint32_t n, Args&&... args) {
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename DispatchFunc, typename... Args>
|
||||
typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
|
||||
uint32_t m, uint32_t n, Args&&... args) {
|
||||
static_assert(std::is_same_v<InType, int8_t>);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM32NBig =
|
||||
typename sm90_int8_config_M32_NBig<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM32NSmall =
|
||||
typename sm90_int8_config_M32_NSmall<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
|
||||
bool const is_small_n = n < 8192;
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 32) {
|
||||
// m in [1, 32]
|
||||
if (is_small_n) {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM32NSmall>(
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM32NBig>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (32, 64]
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM64>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM128>(
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch to GEMM implementations based on element types
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
uint32_t const m = out.size(0);
|
||||
uint32_t const n = out.size(1);
|
||||
|
||||
// TODO: add dispatch functions to all of these
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue, GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue,
|
||||
GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue,
|
||||
GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<
|
||||
cutlass::float_e4m3_t, cutlass::half_t, Epilogue, GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else if (a.dtype() == torch::kFloat16) {
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
|
||||
return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
|
||||
Epilogue, GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else { // a.dtype() == torch::kBFloat16
|
||||
TORCH_CHECK(a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(out.dtype() == torch::kBFloat16);
|
||||
|
||||
return cutlass_gemm_sm90_16bit_dispatch<
|
||||
cutlass::bfloat16_t, cutlass::bfloat16_t, Epilogue, GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"CUTLASS scaled_mm bias dtype must match output dtype ",
|
||||
out.dtype());
|
||||
return cutlass_scaled_sparse_mm_sm90_epilogue<
|
||||
c3x::ScaledEpilogueColumnBias>(out, a, bt_nzs, bt_meta, b_scales,
|
||||
a_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, bt_nzs, bt_meta, b_scales, a_scales);
|
||||
}
|
||||
}
|
||||
|
||||
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
|
||||
// These m and n variables are fordispatching to different GEMM algorithms.
|
||||
uint32_t const m = 1; // Set M to 1 for compression
|
||||
uint32_t const n = a.size(1);
|
||||
|
||||
// Note: For correctness, the compressed format must be invariant in:
|
||||
// - M, the flattened number of tokens
|
||||
// - Whether output dtype is fp16 or bf16
|
||||
// - CUTLASS epilogues
|
||||
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
c3x::TrivialEpilogue,
|
||||
GemmCompressorTraits>(m, n, a);
|
||||
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<
|
||||
cutlass::float_e4m3_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
|
||||
GemmCompressorTraits>(m, n, a);
|
||||
} else if (a.dtype() == torch::kFloat16) {
|
||||
return cutlass_gemm_sm90_16bit_dispatch<
|
||||
cutlass::bfloat16_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
|
||||
GemmCompressorTraits>(m, n, a);
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kBFloat16,
|
||||
"cutlass_sparse_compress only supports int8, fp8_e4m3, fp16, "
|
||||
"and bf16 datatypes");
|
||||
return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
|
||||
c3x::TrivialEpilogue,
|
||||
GemmCompressorTraits>(m, n, a);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,570 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
This file defines 2:4 sparse GEMM operations using the CUTLASS 3.x API,
|
||||
for NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
|
||||
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
|
||||
|
||||
/*
|
||||
* cutlass_sparse_3x_gemm defines a 2:4 sparse GEMM kernel via CUTLASS
|
||||
* for SM90 Hopper systems.
|
||||
*/
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_sparse_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutC_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, LayoutC_Transpose, AlignmentCD, ElementD,
|
||||
LayoutC_Transpose, AlignmentCD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
// Sparse compressor definitions
|
||||
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
using LayoutTagA = cutlass::layout::RowMajor;
|
||||
using CompressorUtility =
|
||||
cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
|
||||
SparseConfig>;
|
||||
using CompressorKernel =
|
||||
cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
|
||||
SparseConfig, cutlass::arch::Sm90>;
|
||||
using Compressor =
|
||||
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
};
|
||||
|
||||
/*
|
||||
* This class defines kernel to compress a 2:4 sparse matrix.
|
||||
* The particular format is defined by the Gemm template parameter,
|
||||
* which is a cutlass_sparse_3x_gemm.
|
||||
*/
|
||||
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
template <typename Gemm>
|
||||
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
|
||||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(a.dim() == 2)
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
|
||||
TORCH_CHECK(a.stride(1) == 1)
|
||||
|
||||
using GemmKernel = typename Gemm::KernelType;
|
||||
using ElementA = typename Gemm::ElementAB;
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
using ProblemShape = typename GemmKernel::ProblemShape;
|
||||
ProblemShape prob_shape{m, 1, k, 1};
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
|
||||
using CompressorUtility = typename Gemm::CompressorUtility;
|
||||
CompressorUtility compressor_utility(prob_shape, a_stride);
|
||||
|
||||
// Allocate buffers for the metadata E and the compressed matrix A
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int MC = compressor_utility.get_tensorA_m_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
auto const a_meta_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto const a_nzs_options =
|
||||
torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
|
||||
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
|
||||
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
|
||||
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = a.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
using Compressor = typename Gemm::Compressor;
|
||||
typename Compressor::Arguments arguments{
|
||||
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return {a_meta, a_nzs};
|
||||
}
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
// Interface stride expected from the argument a (will get transposed)
|
||||
// We compute C^T = B^T * A^T, but we assume B is transposed before
|
||||
// compression and hence the bt_* naming
|
||||
using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
|
||||
// M, N, K after transposition
|
||||
int32_t m = out.size(1);
|
||||
int32_t n = out.size(0);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = Stride<Int<1>, int64_t, int64_t>;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
||||
StrideC c_stride{Int<1>{}, ldc, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
LayoutB b_layout = SparseConfig::fill_layoutA(prob_shape);
|
||||
LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape);
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
|
||||
auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
b_ptr, b_layout, a_ptr, a_stride, e_ptr, e_layout};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
// Gemm Configs are defined below
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default {};
|
||||
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<half_t, OutType, Epilogue> {
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
//////////////////////// Cherry-Picking Kernels ////////////////////////
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_1 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_2 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_128, _64, _256>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_3 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_4 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_5 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_6 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_7 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_8 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_128, _256, _128>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
|
||||
// M in (128, inf)
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
|
||||
TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M64 {
|
||||
// M in [1, 64]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M256 {
|
||||
// M in (128, 256]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M512 {
|
||||
// M in (256, ]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<int8_t, OutType, Epilogue> {
|
||||
// For M > 128 and any N
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M128 {
|
||||
// For M in (64, 128] and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M64 {
|
||||
// For M in (32, 64] and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M32_NBig {
|
||||
// For M in [1, 32] and N >= 8192
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _4, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M32_NSmall {
|
||||
// For M in [1, 32] and N < 8192
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -1,104 +0,0 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
|
||||
// sparse CUTLASS kernels need exactly hopper and are not forward compatible
|
||||
// CUDA 12.2 and SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
return CUDA_VERSION >= 12020 && cuda_device_capability == 90;
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& e,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
|
||||
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
|
||||
a.size(0) == c.size(0));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
|
||||
c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(bt_nzs.stride(0) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == bt_nzs.size(0) && bias->is_contiguous() &&
|
||||
bias->dim() == 1);
|
||||
}
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
// We build for 9.0a which is not forward compatible, so restrict this to
|
||||
// Hopper only
|
||||
if (version_num == 90) {
|
||||
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
|
||||
bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a) {
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
// We build for 9.0a which is not forward compatible, so restrict this to
|
||||
// Hopper only
|
||||
if (version_num == 90) {
|
||||
std::vector<torch::Tensor> result_tensors;
|
||||
|
||||
auto [a_meta, a_nzs] = cutlass_sparse_compress_sm90(a);
|
||||
result_tensors.push_back(std::move(a_nzs));
|
||||
result_tensors.push_back(std::move(a_meta));
|
||||
return result_tensors;
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_sparse_compress for a compute capability equal to "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
@@ -523,26 +523,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
&cutlass_scaled_mm_supports_block_fp8);
|
||||
|
||||
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
|
||||
// given capability
|
||||
ops.def(
|
||||
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_sparse_scaled_mm_supported",
|
||||
&cutlass_sparse_scaled_mm_supported);
|
||||
|
||||
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
|
||||
" Tensor bt_nzs,"
|
||||
" Tensor bt_meta, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
|
||||
|
||||
// CUTLASS sparse matrix compressor
|
||||
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
|
||||
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for sparse cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
# This function checks that applying an identity matrix multiplication
|
||||
# to the compressed weights yields the original uncompressed weights.
|
||||
def check_compress_decompress_invariance(
|
||||
dtype: torch.dtype,
|
||||
b: torch.Tensor,
|
||||
b_compressed: torch.Tensor,
|
||||
b_metadata: torch.Tensor,
|
||||
):
|
||||
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
|
||||
# same dtype as its inputs. This line addresses that constraint while
|
||||
# arbitrarily using bfloat16 for the int8/fp8 cases.
|
||||
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
|
||||
|
||||
eye = torch.eye(b.shape[0], device="cuda", dtype=dtype)
|
||||
eye_scale = torch.ones(1, device="cuda", dtype=torch.float32)
|
||||
b_decomp = ops.cutlass_scaled_sparse_mm(
|
||||
eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device="cuda")
|
||||
b = torch.randn((n, k), device="cuda").t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
# ensure A and B aren't all zeros after rounding
|
||||
a = a * 5.0
|
||||
b = b * 5.0
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
check_compress_decompress_invariance(dtype, b, b_compressed, e)
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
def test_cutlass_sparse_subset():
|
||||
big_m = 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k)
|
||||
a = whole_a[0:m, 0:k]
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16
|
||||
)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
(1, 24576, 512),
|
||||
(16, 256, 512),
|
||||
(16, 16384, 128),
|
||||
(16, 24576, 4096),
|
||||
(32, 8192, 4096),
|
||||
(32, 16384, 4096),
|
||||
(33, 1024, 1024),
|
||||
(33, 8192, 128),
|
||||
(64, 2048, 512),
|
||||
(64, 16384, 1024),
|
||||
(100, 8192, 512),
|
||||
(128, 32768, 4096),
|
||||
(256, 4096, 4096),
|
||||
(512, 256, 1024),
|
||||
(512, 8192, 4096),
|
||||
(512, 16384, 128),
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_gemm(
|
||||
m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool
|
||||
):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_int8_gemm(
|
||||
m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
|
||||
):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
|
||||
@@ -12,7 +12,6 @@ from compressed_tensors.quantization import QuantizationType
|
||||
from tests.models.utils import check_logprobs_close
|
||||
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensors24,
|
||||
CompressedTensorsLinearMethod,
|
||||
CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A8Fp8,
|
||||
@@ -27,9 +26,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
|
||||
@@ -362,283 +358,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.",
|
||||
)
|
||||
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, format="dense"):
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensors24)
|
||||
|
||||
assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
|
||||
assert qkv_proj.scheme.input_quant.strategy == input_strategy
|
||||
assert qkv_proj.scheme.quantized
|
||||
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
|
||||
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
|
||||
assert sparsity_map.get("Linear").format == format
|
||||
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda() or not current_platform.has_device_capability(90),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"args_2of4",
|
||||
[
|
||||
(
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
|
||||
"channel",
|
||||
"token",
|
||||
),
|
||||
(
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
|
||||
"channel",
|
||||
"tensor",
|
||||
),
|
||||
(
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
|
||||
"tensor",
|
||||
"tensor",
|
||||
),
|
||||
(
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
|
||||
"tensor",
|
||||
"token",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
|
||||
model, weight_strategy, input_strategy = args_2of4
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
|
||||
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=4)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda() or not current_platform.has_device_capability(90),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"args_2of4",
|
||||
[
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
|
||||
"channel",
|
||||
"token",
|
||||
),
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
|
||||
"channel",
|
||||
"tensor",
|
||||
),
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
|
||||
"tensor",
|
||||
"token",
|
||||
),
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
|
||||
"tensor",
|
||||
"tensor",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
|
||||
model, weight_strategy, input_strategy = args_2of4
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
|
||||
_test_2of4_quant_models(
|
||||
qkv_proj,
|
||||
weight_strategy,
|
||||
input_strategy,
|
||||
format="sparse-24-bitmask",
|
||||
)
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=4)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="cutlass is not yet supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"args_2of4",
|
||||
[
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
|
||||
"channel",
|
||||
"token",
|
||||
),
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
|
||||
"channel",
|
||||
"tensor",
|
||||
),
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
|
||||
"tensor",
|
||||
"token",
|
||||
),
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
|
||||
"tensor",
|
||||
"tensor",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
|
||||
model, weight_strategy, input_strategy = args_2of4
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert qkv_proj.scheme.weights_dtype == torch.int8
|
||||
_test_2of4_quant_models(
|
||||
qkv_proj,
|
||||
weight_strategy,
|
||||
input_strategy,
|
||||
format="sparse-24-bitmask",
|
||||
)
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=4)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"args_2of4",
|
||||
[
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
|
||||
"channel",
|
||||
"token",
|
||||
),
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
|
||||
"tensor",
|
||||
"tensor",
|
||||
),
|
||||
(
|
||||
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
|
||||
"tensor",
|
||||
"token",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
|
||||
model, weight_strategy, input_strategy = args_2of4
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert qkv_proj.scheme.weights_dtype == torch.int8
|
||||
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=4)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="2of4 Sparse is not yet supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"args_2of4",
|
||||
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
|
||||
)
|
||||
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
|
||||
model = args_2of4
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensors24)
|
||||
|
||||
assert qkv_proj.scheme.weight_quant is None
|
||||
assert qkv_proj.scheme.input_quant is None
|
||||
assert not qkv_proj.scheme.quantized
|
||||
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
|
||||
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
|
||||
assert sparsity_map.get("Linear").format == "dense"
|
||||
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=4)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Cutlass is not yet supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]
|
||||
)
|
||||
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
|
||||
model = args_2of4
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensors24)
|
||||
|
||||
assert qkv_proj.scheme.weight_quant is None
|
||||
assert qkv_proj.scheme.input_quant is None
|
||||
assert not qkv_proj.scheme.quantized
|
||||
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
|
||||
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
|
||||
assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
|
||||
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=4)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
|
||||
@@ -20,8 +20,6 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
|
||||
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
|
||||
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
|
||||
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
|
||||
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
|
||||
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
|
||||
awq, casperhansen/mixtral-instruct-awq, main
|
||||
awq_marlin, casperhansen/mixtral-instruct-awq, main
|
||||
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
|
||||
|
||||
@@ -876,10 +876,6 @@ def cutlass_scaled_mm_azp(
|
||||
return out.view(*target_shape)
|
||||
|
||||
|
||||
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
|
||||
return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability)
|
||||
|
||||
|
||||
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
|
||||
if cuda_device_capability < 90 or cuda_device_capability >= 110:
|
||||
return False
|
||||
@@ -890,94 +886,6 @@ def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def cutlass_sparse_compress(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compresses a sparse matrix for use with Cutlass sparse operations.
|
||||
|
||||
This function takes a dense tensor and compresses it into two components:
|
||||
non-zero elements and metadata. The compressed representation is compatible
|
||||
with Cutlass sparse kernels.
|
||||
|
||||
Args:
|
||||
a (torch.Tensor):
|
||||
The input tensor to be compressed. Must have one of the following data types:
|
||||
- `torch.int8`
|
||||
- `torch.float8_e4m3fn`
|
||||
- `torch.bfloat16`
|
||||
- `torch.float16`
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]:
|
||||
A tuple containing:
|
||||
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
|
||||
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
|
||||
|
||||
Raises:
|
||||
ValueError: If the compression operation fails.
|
||||
|
||||
Notes:
|
||||
- The `a_meta` tensor has a data type of `torch.uint8`.
|
||||
- Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
|
||||
- The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
|
||||
- The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
|
||||
"""
|
||||
assert a.dtype in [torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16]
|
||||
assert a.is_contiguous()
|
||||
|
||||
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
|
||||
elemsPerMetaElem = 4
|
||||
assert a.shape[1] % (2 * elemsPerMetaElem) == 0
|
||||
|
||||
return torch.ops._C.cutlass_sparse_compress(a)
|
||||
|
||||
|
||||
def cutlass_scaled_sparse_mm(
|
||||
a: torch.Tensor,
|
||||
bt_nzs: torch.Tensor,
|
||||
bt_meta: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Performs a scaled sparse matrix multiplication using Cutlass.
|
||||
|
||||
Steps:
|
||||
1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
|
||||
`a = torch.randn((m, k), device='cuda')`.
|
||||
|
||||
2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
|
||||
`b = torch.randn((k, n), device='cuda')`.
|
||||
|
||||
3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
|
||||
`b = prune_to_2_4(b, dim=0)`.
|
||||
|
||||
4. Compress the transposed sparse matrix `b.t()`:
|
||||
`bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.
|
||||
|
||||
5. Perform sparse matrix multiplication using the compressed matrix,
|
||||
applying scaling factors for `a` and `b`, and the output data type:
|
||||
`out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.
|
||||
|
||||
Returns:
|
||||
- The result of the scaled sparse matrix multiplication.
|
||||
"""
|
||||
assert bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0
|
||||
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
|
||||
assert bias is None or bias.shape[0] == bt_nzs.shape[0] and bias.dtype == out_dtype
|
||||
|
||||
m = a.shape[0]
|
||||
n = bt_nzs.shape[0]
|
||||
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
||||
|
||||
torch.ops._C.cutlass_scaled_sparse_mm(
|
||||
out, a, bt_nzs, bt_meta, scale_a, scale_b, bias
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_cutlass_moe_mm_data(
|
||||
topk_ids: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
|
||||
@@ -5,39 +5,16 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat, ModelCompressor
|
||||
from compressed_tensors.quantization import (
|
||||
QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType,
|
||||
)
|
||||
from compressed_tensors.utils import combine_shards
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
sparse_cutlass_supported,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensors24"]
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class CompressedTensors24(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
@@ -47,33 +24,11 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
input_quant: QuantizationArgs | None = None,
|
||||
model_compression_config: dict[str, Any] | None = None,
|
||||
):
|
||||
self.quantized = quantized
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
model_compressor = ModelCompressor.from_compression_config(
|
||||
model_compression_config
|
||||
)
|
||||
self.do_sparse_decompress = (
|
||||
model_compressor is not None
|
||||
and model_compressor.sparsity_config.format
|
||||
== CompressionFormat.sparse_24_bitmask.value
|
||||
)
|
||||
if self.do_sparse_decompress:
|
||||
self.model_compressor = model_compressor
|
||||
|
||||
if (
|
||||
quantized
|
||||
and input_quant is not None
|
||||
and self._get_quant_dtype() == current_platform.fp8_dtype()
|
||||
):
|
||||
static = not input_quant.dynamic
|
||||
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.quant_fp8 = QuantFP8(static, g_shape)
|
||||
raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Only cutlass 3.x kernels are implemented so far
|
||||
return 90
|
||||
raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -85,164 +40,10 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
if not sparse_cutlass_supported():
|
||||
raise ValueError(
|
||||
"Sparse CUTLASS not supported. vLLM must be built with "
|
||||
"CUDA 12.2 or later to use this feature"
|
||||
)
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size = input_size
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
|
||||
|
||||
# parameter to store uncompressed weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.do_sparse_decompress:
|
||||
assert all(
|
||||
partition_size % 8 == 0 for partition_size in output_partition_sizes
|
||||
), "All partitions must be divisible by 8 for "
|
||||
"2:4 sparse compressed models"
|
||||
|
||||
shape = BasevLLMParameter(
|
||||
data=torch.empty(2, 1, dtype=torch.int64),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
compressed_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
bitmask = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 8,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("shape", shape)
|
||||
layer.register_parameter("compressed", compressed_weight)
|
||||
layer.register_parameter("bitmask", bitmask)
|
||||
|
||||
# Check if quantized, not just 2:4 Sparse
|
||||
if self.quantized:
|
||||
if (
|
||||
self.weight_quant
|
||||
and self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
):
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
(sum(output_partition_sizes), 1), dtype=torch.float32
|
||||
),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.weight_quant
|
||||
and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
)
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# input quant will be non-none
|
||||
if self.input_quant and not self.input_quant.dynamic:
|
||||
# register input quant scale
|
||||
assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
else:
|
||||
# for sparse-only, pass in 1 for weight/input scales
|
||||
weight_scale = torch.nn.Parameter(
|
||||
data=torch.ones(1, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
input_scale = torch.nn.Parameter(
|
||||
data=torch.ones(1, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""
|
||||
Compress weights after loading. Store compressed weight and meta
|
||||
tensor
|
||||
|
||||
:post-condition: layer.w_compressed and layer.meta are
|
||||
set to the compressed weight and meta tensor in the
|
||||
format expected by the Cutlass kernels
|
||||
:param layer: The layer with the weights to be processed
|
||||
|
||||
"""
|
||||
if self.do_sparse_decompress:
|
||||
layer.weight.data = self._decompress_bitmask_compressed_weight(
|
||||
compressed=layer.compressed,
|
||||
bitmask=layer.bitmask,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
# compressed and bitmask tensors
|
||||
# are no longer needed after decompression
|
||||
del layer.compressed
|
||||
del layer.bitmask
|
||||
|
||||
# torch.compile workaround
|
||||
if hasattr(layer, "input_scale"):
|
||||
layer.input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
if self.weight_quant:
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
convert_to_channelwise(
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
# torch.compile workaround
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
# Set all negative zero values to 0 prior to compression
|
||||
if layer.weight.dtype.is_floating_point and layer.weight.dtype.itemsize >= 2:
|
||||
layer.weight.data[layer.weight.data == -0.0] = 0.0
|
||||
|
||||
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
|
||||
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
|
||||
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
|
||||
raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@@ -250,143 +51,4 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns the output tensor for the layer with 2:4
|
||||
sparse compressed weights, given the input tensor
|
||||
and bias
|
||||
|
||||
:param layer: The layer with 2:4 sparse compressed
|
||||
weights to be used for the computation
|
||||
:param x: The input tensor to the layer
|
||||
:param bias: The bias to be added to the output tensor
|
||||
:return: The output tensor of the layer
|
||||
"""
|
||||
if self.quantized:
|
||||
scale = getattr(layer, "input_scale", None)
|
||||
|
||||
if self.weights_dtype == torch.int8:
|
||||
ops_output = ops.scaled_int8_quant(x, scale=scale)
|
||||
q_input = ops_output[0]
|
||||
input_scale = ops_output[1]
|
||||
else:
|
||||
assert self.weights_dtype == torch.float8_e4m3fn
|
||||
q_input, input_scale = self.quant_fp8(x, scale=scale)
|
||||
|
||||
else:
|
||||
# Not quantized, nothing to do with the input_scales, use as is
|
||||
input_scale = layer.input_scale
|
||||
q_input = x
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a=q_input,
|
||||
bt_nzs=layer.weight,
|
||||
bt_meta=layer.meta,
|
||||
scale_a=input_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
assert out.is_contiguous()
|
||||
return out
|
||||
|
||||
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
|
||||
if not self.quantized:
|
||||
return params_dtype
|
||||
return self._get_quant_dtype()
|
||||
|
||||
def _get_quant_dtype(self) -> torch.dtype:
|
||||
assert self.quantized
|
||||
assert self.weight_quant is not None
|
||||
assert self.input_quant is not None
|
||||
|
||||
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
|
||||
|
||||
if not is_8_bits:
|
||||
raise ValueError("Cutlass only supports 8-bit quantization")
|
||||
|
||||
if (
|
||||
self.weight_quant.type == QuantizationType.FLOAT
|
||||
and self.input_quant.type == QuantizationType.FLOAT
|
||||
):
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
if (
|
||||
self.weight_quant.type == QuantizationType.INT
|
||||
and self.input_quant.type == QuantizationType.INT
|
||||
):
|
||||
return torch.int8
|
||||
|
||||
raise ValueError("Quantization type not supported by Cutlass")
|
||||
|
||||
def _decompress_bitmask_compressed_weight(
|
||||
self,
|
||||
compressed: torch.Tensor,
|
||||
bitmask: torch.Tensor,
|
||||
layer: torch.nn.Module,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
|
||||
return the result.
|
||||
|
||||
This function also supports sharded decompression.
|
||||
|
||||
:param compressed: The 2:4 sparse weight tensor compressed using the
|
||||
sparse-24-bitmask compressor. This is different from
|
||||
`cutlass_sparse_compress` which uses a different scheme (2 bits for
|
||||
every nonzero element that represent the coordinate within the block
|
||||
of 4). The bitmask compression here uses a bitmask to indicate the
|
||||
positions of non-zero elements.
|
||||
:param bitmask: The 2:4 bitmask associated with the compressed weights,
|
||||
representing the positions of non-zero elements in the compressed
|
||||
tensor.
|
||||
:param layer: The layer whose weights need to be processed after
|
||||
loading.
|
||||
:return: The decompressed 2:4 sparse weight tensor.
|
||||
"""
|
||||
|
||||
sparsity_compressor = self.model_compressor.sparsity_compressor
|
||||
|
||||
def _process_split(
|
||||
bitmask_compressed_weight: torch.Tensor,
|
||||
shape,
|
||||
bitmask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
weight_data = dict(
|
||||
compressed=bitmask_compressed_weight,
|
||||
shape=shape,
|
||||
bitmask=bitmask,
|
||||
)
|
||||
return sparsity_compressor.decompress_weight(weight_data)
|
||||
|
||||
split_weights: list[torch.Tensor] = []
|
||||
split_bitmask: list[torch.Tensor] = []
|
||||
split_shape: list[tuple[int, int]] = []
|
||||
|
||||
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
|
||||
split_weights = torch.split(compressed, layer.logical_widths)
|
||||
split_bitmask = torch.split(bitmask, layer.logical_widths)
|
||||
split_shape = [
|
||||
(out, layer.input_size_per_partition) for out in layer.logical_widths
|
||||
]
|
||||
|
||||
if split_weights:
|
||||
decompressed_shards = [
|
||||
_process_split(compressed_weight, shape, bitmask)
|
||||
for compressed_weight, shape, bitmask in zip(
|
||||
split_weights, split_shape, split_bitmask
|
||||
)
|
||||
]
|
||||
decompressed = combine_shards(decompressed_shards)
|
||||
else:
|
||||
decompressed = sparsity_compressor.decompress_weight(
|
||||
dict(
|
||||
compressed=compressed,
|
||||
shape=(
|
||||
layer.logical_widths[0],
|
||||
layer.input_size_per_partition,
|
||||
),
|
||||
bitmask=bitmask,
|
||||
)
|
||||
)
|
||||
return decompressed
|
||||
raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
|
||||
|
||||
@@ -20,7 +20,7 @@ class CompressedTensorsScheme(ABC):
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
@@ -28,7 +28,7 @@ class CompressedTensorsScheme(ABC):
|
||||
Weight creation for the particular scheme. Inputs to this function
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(
|
||||
@@ -44,7 +44,7 @@ class CompressedTensorsScheme(ABC):
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
@@ -52,4 +52,4 @@ class CompressedTensorsScheme(ABC):
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -8,16 +8,6 @@ from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def sparse_cutlass_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
|
||||
return ops.cutlass_sparse_scaled_mm_supported(capability)
|
||||
|
||||
|
||||
def cutlass_fp8_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user