[Misc] support model prefix & add deepseek vl2 tiny fused moe config (#17763)

Signed-off-by: 唯勤 <xsank.mz@alibaba-inc.com>
Co-authored-by: 唯勤 <xsank.mz@alibaba-inc.com>

commit 0a9bbaa104 (parent 39956efb3f)
Author: xsank
Date: 2025-05-08 15:50:22 +08:00
Committed by: GitHub
2 changed files with 161 additions and 9 deletions


@@ -6,15 +6,16 @@ import time
 from contextlib import nullcontext
 from datetime import datetime
 from itertools import product
+from types import SimpleNamespace
 from typing import Any, TypedDict

 import ray
 import torch
 from ray.experimental.tqdm_ray import tqdm
-from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.transformers_utils.config import get_config
 from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser
@@ -534,8 +535,12 @@ def get_weight_block_size_safety(config, default_value=None):
 def main(args: argparse.Namespace):
     print(args)

-    config = AutoConfig.from_pretrained(
-        args.model, trust_remote_code=args.trust_remote_code)
+    config = get_config(model=args.model,
+                        trust_remote_code=args.trust_remote_code)
+    if args.model_prefix:
+        config = getattr(config, args.model_prefix)
+        config = SimpleNamespace(**config)
     if config.architectures[0] == "DbrxForCausalLM":
         E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
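
For readers unfamiliar with the new flag: get_config returns the checkpoint's top-level config, and for a multi-modal model such as deepseek-vl2 the text model's MoE settings live in a nested sub-config, so --model-prefix selects that nested dict and SimpleNamespace restores attribute access. A minimal sketch of the indirection, with an invented key name and invented values (nothing below is read from a real checkpoint):

# Hypothetical sketch of the --model-prefix indirection above. The
# "language_config" key and every value are illustrative only.
from types import SimpleNamespace

raw_config = {
    "language_config": {  # assumed sub-config key for a VL checkpoint
        "architectures": ["DeepseekV2ForCausalLM"],
        "n_routed_experts": 64,
        "num_experts_per_tok": 6,
        "moe_intermediate_size": 896,
        "torch_dtype": "bfloat16",
    }
}

prefix = "language_config"
sub = raw_config[prefix]           # the script uses getattr(config, prefix)
config = SimpleNamespace(**sub)    # dict keys become attributes
print(config.architectures[0])     # -> DeepseekV2ForCausalLM
print(config.torch_dtype)          # note: still a plain string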
@@ -546,15 +551,14 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif (config.architectures[0] == "DeepseekV3ForCausalLM"
-          or config.architectures[0] == "DeepseekV2ForCausalLM"):
+    elif (config.architectures[0]
+          in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif config.architectures[0] in [
-            "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
-    ]:
+    elif config.architectures[0] in ("Qwen2MoeForCausalLM",
+                                     "Qwen3MoeForCausalLM"):
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
@@ -569,7 +573,8 @@ def main(args: argparse.Namespace):
         shard_intermediate_size = 2 * intermediate_size // args.tp_size

     hidden_size = config.hidden_size
-    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
+    dtype = torch.float16 if current_platform.is_rocm() else getattr(
+        torch, config.torch_dtype)
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
     block_quant_shape = get_weight_block_size_safety(config)
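
A side effect of pulling the sub-config out of a plain dict is that torch_dtype stays the string stored in config.json (e.g. "bfloat16") instead of a torch.dtype object, which is why the benchmark now resolves it through getattr on the torch module. A short sanity check of that mapping:

import torch

# HF config.json stores torch_dtype as a string such as "bfloat16";
# getattr on the torch module maps the name back to the dtype object.
assert getattr(torch, "bfloat16") is torch.bfloat16
assert getattr(torch, "float16") is torch.float16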
@@ -659,6 +664,7 @@ if __name__ == "__main__":
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--model-prefix", type=str, required=False)
args = parser.parse_args()
main(args)
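
Putting it together, a plausible tuning invocation for the deepseek-vl2 tiny checkpoint might look like the following; the script name and the language_config prefix are assumptions based on the flags added above, not something this diff confirms:

python benchmark_moe.py --model deepseek-ai/deepseek-vl2-tiny \
    --model-prefix language_config --trust-remote-code --tune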