[Model] Add tuned triton fused_moe configs for Qwen3Moe (#17328)
Some checks failed
Create Release / Create Release (push) Has been cancelled
Some checks failed
Create Release / Create Release (push) Has been cancelled
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -527,7 +527,7 @@ def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
block_quant_shape = None
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
@@ -546,8 +546,9 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
elif config.architectures[0] in [
|
||||
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
|
||||
]:
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
@@ -565,6 +566,7 @@ def main(args: argparse.Namespace):
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
|
||||
Reference in New Issue
Block a user