[Kernel] (1/N) Machete - Hopper Optimized Mixed Precision Linear Kernel (#7174)
This commit is contained in:
446
csrc/quantization/machete/generate.py
Normal file
446
csrc/quantization/machete/generate.py
Normal file
@@ -0,0 +1,446 @@
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import jinja2
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
|
||||
EpilogueScheduleType,
|
||||
MixedInputKernelScheduleType,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType, VLLMDataType,
|
||||
VLLMDataTypeNames, VLLMDataTypeTag,
|
||||
VLLMKernelScheduleTag)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
#
|
||||
# Generator templating
|
||||
#
|
||||
|
||||
DISPATCH_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
using GemmDispatcher_ = GemmDispatcher<
|
||||
{{DataTypeTag[type_config.element_a]}}, // ElementA
|
||||
{{DataTypeTag[type_config.element_b]}}, // ElementB
|
||||
{{DataTypeTag[type_config.element_d]}}, // ElementD
|
||||
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
|
||||
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
|
||||
|
||||
{% for s in schedules %}extern torch::Tensor
|
||||
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
|
||||
{% endfor %}
|
||||
template <>
|
||||
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
|
||||
[[maybe_unused]] auto M = args.A.size(0);
|
||||
[[maybe_unused]] auto N = args.B.size(1);
|
||||
[[maybe_unused]] auto K = args.A.size(1);
|
||||
|
||||
if (!args.schedule) {
|
||||
{%- for cond, s in heuristic %}
|
||||
{%if cond is not none%}if ({{cond}})
|
||||
{%- else %}else
|
||||
{%- endif %}
|
||||
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
|
||||
}
|
||||
|
||||
{% for s in schedules %}
|
||||
if (*args.schedule == "{{ gen_sch_name(s) }}") {
|
||||
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
|
||||
}
|
||||
{% endfor %}
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
|
||||
"schedule = ", *args.schedule);
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<std::string> GemmDispatcher_::supported_schedules() {
|
||||
return {
|
||||
{% for s in schedules -%}
|
||||
"{{ gen_sch_name(s) }}"{{ ",
|
||||
" if not loop.last }}{%- endfor %}
|
||||
};
|
||||
}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
IMPL_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
|
||||
using Kernel = MacheteKernelTemplate<
|
||||
{{DataTypeTag[type_config.element_a]}}, // ElementA
|
||||
{{DataTypeTag[type_config.element_b]}}, // ElementB
|
||||
{{DataTypeTag[type_config.element_d]}}, // ElementD
|
||||
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
|
||||
{{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
|
||||
Config, with_C, with_scales, with_zeropoints>;
|
||||
|
||||
{% for sch in schedules %}
|
||||
{% set schedule_name = gen_sch_name(sch) -%}
|
||||
struct sch_{{schedule_name}} {
|
||||
using TileShapeNM = Shape<{{
|
||||
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
|
||||
using ClusterShape = Shape<{{
|
||||
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
|
||||
// TODO: Reimplement
|
||||
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
|
||||
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
|
||||
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
};
|
||||
|
||||
torch::Tensor
|
||||
impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
|
||||
bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
|
||||
with_zeropoints = args.zeros.has_value();
|
||||
|
||||
{% for s in specializations %}
|
||||
if (with_C == {{s.with_C|lower}}
|
||||
&& with_zeropoints == {{s.with_zeropoints|lower}}
|
||||
&& with_scales == {{s.with_scales|lower}}) {
|
||||
return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
|
||||
{{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
|
||||
}{% endfor %}
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "for the sake of compile times and binary size machete_mm(..) is "
|
||||
" not implemented for with_C=", with_C, ", with_scales=", with_scales,
|
||||
", with_zeropoints=", with_zeropoints,
|
||||
" (for {{type_name}}_sch_{{schedule_name}})");
|
||||
}
|
||||
{% endfor %}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
PREPACK_TEMPLATE = """
|
||||
#include "../machete_prepack_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
using PrepackBDispatcher_ = PrepackBDispatcher<
|
||||
{{DataTypeTag[type_config.element_a]}}, // ElementA
|
||||
{{DataTypeTag[type_config.element_b]}}, // ElementB
|
||||
{{DataTypeTag[type_config.element_d]}}, // ElementD
|
||||
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
|
||||
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
|
||||
|
||||
using PrepackedLayoutB = PrepackedLayoutBTemplate<
|
||||
{{DataTypeTag[type_config.element_a]}}, // ElementA
|
||||
{{DataTypeTag[type_config.element_b]}}, // ElementB
|
||||
{{DataTypeTag[type_config.element_d]}}, // ElementD
|
||||
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
|
||||
|
||||
template <>
|
||||
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
|
||||
return prepack_impl<PrepackedLayoutB>(B);
|
||||
}
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
|
||||
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScheduleConfig:
|
||||
tile_shape_mn: Tuple[int, int]
|
||||
cluster_shape_mnk: Tuple[int, int, int]
|
||||
kernel_schedule: MixedInputKernelScheduleType
|
||||
epilogue_schedule: EpilogueScheduleType
|
||||
tile_scheduler: TileSchedulerType
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeConfig:
|
||||
element_a: DataType
|
||||
element_b: Union[DataType, VLLMDataType]
|
||||
element_b_scale: DataType
|
||||
element_b_zeropoint: DataType
|
||||
element_d: DataType
|
||||
accumulator: DataType
|
||||
|
||||
|
||||
@dataclass
|
||||
class Specialization:
|
||||
with_C: bool
|
||||
with_zeropoints: bool
|
||||
with_scales: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImplConfig:
|
||||
type_config: TypeConfig
|
||||
schedule_configs: List[ScheduleConfig]
|
||||
specializations: List[Specialization]
|
||||
heuristic: List[Tuple[Optional[str], ScheduleConfig]]
|
||||
|
||||
|
||||
def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
|
||||
tile_shape = (
|
||||
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
||||
)
|
||||
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
|
||||
f"x{schedule_config.cluster_shape_mnk[1]}" +
|
||||
f"x{schedule_config.cluster_shape_mnk[2]}")
|
||||
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
|
||||
.split("::")[-1]
|
||||
epilogue_schedule = EpilogueScheduleTag[
|
||||
schedule_config.epilogue_schedule].split("::")[-1]
|
||||
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
|
||||
.split("::")[-1]
|
||||
|
||||
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
|
||||
f"_{epilogue_schedule}_{tile_scheduler}")
|
||||
|
||||
|
||||
# mostly unique shorter schedule_name
|
||||
def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
|
||||
kernel_terse_names_replace = {
|
||||
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
|
||||
"TmaWarpSpecializedCooperative_": "TmaCoop_",
|
||||
"StreamKScheduler": "streamK",
|
||||
}
|
||||
|
||||
schedule_name = generate_schedule_name(schedule_config)
|
||||
for orig, terse in kernel_terse_names_replace.items():
|
||||
schedule_name = schedule_name.replace(orig, terse)
|
||||
return schedule_name
|
||||
|
||||
|
||||
# unique type_name
|
||||
def generate_type_signature(kernel_type_config: TypeConfig):
|
||||
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
|
||||
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
|
||||
element_d = VLLMDataTypeNames[kernel_type_config.element_d]
|
||||
accumulator = VLLMDataTypeNames[kernel_type_config.accumulator]
|
||||
element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale]
|
||||
element_zeropoint = VLLMDataTypeNames[
|
||||
kernel_type_config.element_b_zeropoint]
|
||||
|
||||
return (f"{element_a}{element_b}{element_d}"
|
||||
f"{accumulator}{element_scale}{element_zeropoint}")
|
||||
|
||||
|
||||
# non-unique shorter type_name
|
||||
def generate_terse_type_signature(kernel_type_config: TypeConfig):
|
||||
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
|
||||
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
|
||||
|
||||
return f"{element_a}{element_b}"
|
||||
|
||||
|
||||
def is_power_of_two(n):
|
||||
return (n != 0) and (n & (n - 1) == 0)
|
||||
|
||||
|
||||
def to_cute_constant(value: List[int]):
|
||||
|
||||
def _to_cute_constant(value: int):
|
||||
if is_power_of_two(value):
|
||||
return f"_{value}"
|
||||
else:
|
||||
return f"Int<{value}>"
|
||||
|
||||
if isinstance(value, Iterable):
|
||||
return [_to_cute_constant(value) for value in value]
|
||||
else:
|
||||
return _to_cute_constant(value)
|
||||
|
||||
|
||||
template_globals = {
|
||||
"DataTypeTag": VLLMDataTypeTag,
|
||||
"KernelScheduleTag": VLLMKernelScheduleTag,
|
||||
"EpilogueScheduleTag": EpilogueScheduleTag,
|
||||
"TileSchedulerTag": TileSchedulerTag,
|
||||
"to_cute_constant": to_cute_constant,
|
||||
"gen_sch_name": generate_terse_schedule_name,
|
||||
}
|
||||
|
||||
|
||||
def create_template(template_str):
|
||||
template = jinja2.Template(template_str)
|
||||
template.globals.update(template_globals)
|
||||
return template
|
||||
|
||||
|
||||
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
|
||||
mm_impl_template = create_template(IMPL_TEMPLATE)
|
||||
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||
|
||||
|
||||
def create_sources(impl_config: ImplConfig, num_impl_files=2):
|
||||
sources = []
|
||||
|
||||
type_name = generate_type_signature(impl_config.type_config)
|
||||
terse_type_name = generate_terse_type_signature(impl_config.type_config)
|
||||
|
||||
sources.append((
|
||||
f"machete_mm_{terse_type_name}",
|
||||
mm_dispatch_template.render(type_name=type_name,
|
||||
type_config=impl_config.type_config,
|
||||
schedules=impl_config.schedule_configs,
|
||||
heuristic=impl_config.heuristic),
|
||||
))
|
||||
|
||||
sources.append((
|
||||
f"machete_prepack_{terse_type_name}",
|
||||
prepack_dispatch_template.render(
|
||||
type_name=type_name,
|
||||
type_config=impl_config.type_config,
|
||||
),
|
||||
))
|
||||
|
||||
num_schedules = len(impl_config.schedule_configs)
|
||||
schedules_per_file = math.ceil(num_schedules / num_impl_files)
|
||||
for part, i in enumerate(range(0, num_schedules, schedules_per_file)):
|
||||
file_schedules = impl_config.schedule_configs[i:i + schedules_per_file]
|
||||
|
||||
sources.append((
|
||||
f"machete_mm_{terse_type_name}_impl_part{part}",
|
||||
mm_impl_template.render(
|
||||
type_name=type_name,
|
||||
type_config=impl_config.type_config,
|
||||
schedules=file_schedules,
|
||||
specializations=impl_config.specializations,
|
||||
),
|
||||
))
|
||||
return sources
|
||||
|
||||
|
||||
def generate():
|
||||
# See csrc/quantization/machete/Readme.md, the Codegeneration for more info
|
||||
# about how this works
|
||||
SCRIPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
schedules = [
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=tile_shape_mn,
|
||||
cluster_shape_mnk=cluster_shape_mnk,
|
||||
kernel_schedule=kernel_schedule,
|
||||
epilogue_schedule=epilogue_schedule,
|
||||
tile_scheduler=tile_scheduler,
|
||||
) for tile_shape_mn, cluster_shape_mnk in (
|
||||
((128, 16), (1, 1, 1)),
|
||||
((128, 32), (1, 1, 1)),
|
||||
((128, 64), (1, 1, 1)),
|
||||
((128, 128), (1, 1, 1)),
|
||||
) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
|
||||
for tile_scheduler in (TileSchedulerType.StreamK, )
|
||||
]
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
default_heuristic = [
|
||||
("M > 64",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
("M > 32",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
("M > 16",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
(None,
|
||||
ScheduleConfig(tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK))
|
||||
]
|
||||
|
||||
impl_configs = []
|
||||
|
||||
GPTQ_kernel_type_configs = list(
|
||||
(TypeConfig(
|
||||
element_a=element_a,
|
||||
element_b=element_b,
|
||||
element_b_scale=element_a,
|
||||
element_b_zeropoint=element_a,
|
||||
element_d=element_a,
|
||||
accumulator=DataType.f32,
|
||||
) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||
for element_a in (DataType.f16, DataType.bf16)))
|
||||
|
||||
GPTQ_kernel_specializations = [
|
||||
Specialization(with_C=False, with_zeropoints=False, with_scales=True)
|
||||
]
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2], x[3])
|
||||
for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules),
|
||||
itertools.repeat(GPTQ_kernel_specializations),
|
||||
itertools.repeat(default_heuristic))
|
||||
]
|
||||
|
||||
AWQ_kernel_type_configs = list(
|
||||
(TypeConfig(
|
||||
element_a=element_a,
|
||||
element_b=element_b,
|
||||
element_b_scale=element_a,
|
||||
element_b_zeropoint=element_a,
|
||||
element_d=element_a,
|
||||
accumulator=DataType.f32,
|
||||
) for element_b in (DataType.u4, DataType.u8)
|
||||
for element_a in (DataType.f16, DataType.bf16)))
|
||||
|
||||
AWQ_kernel_specializations = [
|
||||
Specialization(with_C=False, with_zeropoints=True, with_scales=True)
|
||||
]
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2], x[3])
|
||||
for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules),
|
||||
itertools.repeat(AWQ_kernel_specializations),
|
||||
itertools.repeat(default_heuristic))
|
||||
]
|
||||
|
||||
output_dir = os.path.join(SCRIPT_DIR, "generated")
|
||||
|
||||
# Delete the "generated" directory if it exists
|
||||
if os.path.exists(output_dir):
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
# Create the "generated" directory
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Render each group of configurations into separate files
|
||||
for impl_config in impl_configs:
|
||||
for filename, code in create_sources(impl_config):
|
||||
filepath = os.path.join(output_dir, f"{filename}.cu")
|
||||
with open(filepath, "w") as output_file:
|
||||
output_file.write(code)
|
||||
print(f"Rendered template to {filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate()
|
||||
Reference in New Issue
Block a user