[Perf] Add tuned triton moe config for Qwen3.5 H200, 9.9% E2E throughput improvement (#37340)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-18 14:18:34 -04:00
committed by GitHub
parent 5dd8df0701
commit 0ef7f79054
2 changed files with 180 additions and 9 deletions

View File

@@ -750,17 +750,20 @@ def get_weight_block_size_safety(config, default_value=None):
def get_model_params(config):
if config.architectures[0] == "DbrxForCausalLM":
architectures = getattr(config, "architectures", None) or [type(config).__name__]
architecture = architectures[0]
if architecture == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM":
elif architecture == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in (
elif architecture in (
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
@@ -774,7 +777,7 @@ def get_model_params(config):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in (
elif architecture in (
"Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM",
"Qwen3NextForCausalLM",
@@ -783,23 +786,27 @@ def get_model_params(config):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
elif architecture in (
"Qwen3VLMoeForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
"Qwen3_5MoeTextConfig",
):
text_config = config.get_text_config()
E = text_config.num_experts
topk = text_config.num_experts_per_tok
intermediate_size = text_config.moe_intermediate_size
hidden_size = text_config.hidden_size
elif config.architectures[0] == "HunYuanMoEV1ForCausalLM":
elif architecture == "HunYuanMoEV1ForCausalLM":
E = config.num_experts
topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0]
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen3OmniMoeForConditionalGeneration":
elif architecture == "Qwen3OmniMoeForConditionalGeneration":
E = config.thinker_config.text_config.num_experts
topk = config.thinker_config.text_config.num_experts_per_tok
intermediate_size = config.thinker_config.text_config.moe_intermediate_size
hidden_size = config.thinker_config.text_config.hidden_size
elif config.architectures[0] == "PixtralForConditionalGeneration":
elif architecture == "PixtralForConditionalGeneration":
# Pixtral can contain different LLM architectures,
# recurse to get their parameters
return get_model_params(config.get_text_config())
@@ -814,6 +821,23 @@ def get_model_params(config):
return E, topk, intermediate_size, hidden_size
def resolve_dtype(config) -> torch.dtype:
if current_platform.is_rocm():
return torch.float16
dtype = getattr(config, "dtype", None)
if dtype is not None:
return dtype
if hasattr(config, "get_text_config"):
text_config = config.get_text_config()
dtype = getattr(text_config, "dtype", None)
if dtype is not None:
return dtype
return torch.bfloat16
def get_quantization_group_size(config) -> int | None:
"""Extract the quantization group size from the HF model config.
@@ -861,7 +885,7 @@ def main(args: argparse.Namespace):
else:
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // args.tp_size
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
dtype = resolve_dtype(config)
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
use_int4_w4a16 = args.dtype == "int4_w4a16"

View File

@@ -0,0 +1,147 @@
{
"triton_version": "3.6.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}