[Kernel] Triton Configs for Fp8 Block Quantization (#11589)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: mgoin <michael@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
Robert Shaw
2025-01-30 14:53:22 -05:00
committed by GitHub
parent 41bf5612f5
commit 9b0c4bab36
43 changed files with 5972 additions and 42 deletions

View File

@@ -598,15 +598,27 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
)
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(E: int,
N: int,
dtype: Optional[str],
block_shape: Optional[List[int]] = None) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
block_shape_selector = ("" if not block_shape or not all(block_shape) else
f",block_shape={block_shape}")
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@functools.lru_cache
def get_moe_configs(E: int, N: int,
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
def get_moe_configs(
E: int,
N: int,
dtype: Optional[str],
block_n: Optional[int] = None,
block_k: Optional[int] = None,
) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
@@ -618,7 +630,8 @@ def get_moe_configs(E: int, N: int,
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_config_file_name(E, N, dtype)
block_shape = [block_n, block_k] if block_n and block_k else None
json_file_name = get_config_file_name(E, N, dtype, block_shape)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
@@ -645,21 +658,53 @@ def get_default_config(
topk: int,
dtype: Optional[str],
is_marlin: bool,
block_shape: Optional[List[int]] = None,
) -> Dict[str, int]:
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
if dtype == "fp8_w8a8":
if block_shape is None:
config = {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
}
if M <= E:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
}
else:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_shape[0],
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
else:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
return config
@@ -679,7 +724,9 @@ def try_get_optimal_moe_config(
else:
# First try to load optimal config from the file
E, _, N = w2_shape
configs = get_moe_configs(E, N, dtype)
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k)
if configs:
# If an optimal configuration map has been found, look up the
@@ -688,13 +735,7 @@ def try_get_optimal_moe_config(
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
is_marlin)
# NOTE: For block-wise quant,
# BLOCK_K must be divisible by block_shape[1]
# BLOCK_N and BLOCK_M has no requirements
if block_shape is not None and block_shape[0] != 0:
config["BLOCK_SIZE_N"] = block_shape[0]
config["BLOCK_SIZE_K"] = block_shape[1]
is_marlin, block_shape)
return config