[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)
Co-authored-by: mgoin <michael@neuralmagic.com> Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
|
||||
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleConfig:
|
||||
tile_shape_mn: Tuple[int, int]
|
||||
cluster_shape_mnk: Tuple[int, int, int]
|
||||
@@ -328,56 +328,137 @@ def generate():
|
||||
# about how this works
|
||||
SCRIPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
schedules = [
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=tile_shape_mn,
|
||||
cluster_shape_mnk=cluster_shape_mnk,
|
||||
kernel_schedule=kernel_schedule,
|
||||
epilogue_schedule=epilogue_schedule,
|
||||
tile_scheduler=tile_scheduler,
|
||||
) for tile_shape_mn, cluster_shape_mnk in (
|
||||
((128, 16), (1, 1, 1)),
|
||||
((128, 32), (1, 1, 1)),
|
||||
((128, 64), (1, 1, 1)),
|
||||
((128, 128), (1, 1, 1)),
|
||||
) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
|
||||
for tile_scheduler in (TileSchedulerType.StreamK, )
|
||||
]
|
||||
schedule_common_params = dict(
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
default_heuristic = [
|
||||
("M > 64",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
("M > 32",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
("M > 16",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
(None,
|
||||
ScheduleConfig(tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK))
|
||||
#### M = 257+
|
||||
(
|
||||
"M > 256 && K <= 16384 && N <= 4096",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 256",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 256),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 129-256
|
||||
(
|
||||
"M > 128 && K <= 4096 && N <= 4096",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 128 && K <= 8192 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 128",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 256),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 65-128
|
||||
(
|
||||
"M > 64 && K <= 4069 && N <= 4069",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64 && K <= 4069 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64 && K >= 8192 && N >= 12288",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 33-64
|
||||
(
|
||||
"M > 32 && K <= 6144 && N <= 6144",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 32 && K >= 16384 && N >= 12288",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 32",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 17-32
|
||||
(
|
||||
"M > 16 && K <= 12288 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 16",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 1-16
|
||||
(
|
||||
"N >= 26624",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
None,
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
]
|
||||
|
||||
schedules = list(set([x[1] for x in default_heuristic]))
|
||||
|
||||
impl_configs = []
|
||||
|
||||
GPTQ_kernel_type_configs = list(
|
||||
|
||||
Reference in New Issue
Block a user