[Misc] refactor argument parsing in examples (#16635)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
Reid
2025-04-15 16:05:30 +08:00
committed by GitHub
parent b590adfdc1
commit 6ae996a873
25 changed files with 595 additions and 411 deletions

View File

@@ -34,6 +34,40 @@ from vllm import LLM, SamplingParams
from vllm.utils import get_open_port
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="Data Parallel Inference")
parser.add_argument("--model",
type=str,
default="ibm-research/PowerMoE-3b",
help="Model name or path")
parser.add_argument("--dp-size",
type=int,
default=2,
help="Data parallel size")
parser.add_argument("--tp-size",
type=int,
default=2,
help="Tensor parallel size")
parser.add_argument("--node-size",
type=int,
default=1,
help="Total number of nodes")
parser.add_argument("--node-rank",
type=int,
default=0,
help="Rank of the current node")
parser.add_argument("--master-addr",
type=str,
default="",
help="Master node IP address")
parser.add_argument("--master-port",
type=int,
default=0,
help="Master node port")
return parser.parse_args()
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Data Parallel Inference")
parser.add_argument("--model",
type=str,
default="ibm-research/PowerMoE-3b",
help="Model name or path")
parser.add_argument("--dp-size",
type=int,
default=2,
help="Data parallel size")
parser.add_argument("--tp-size",
type=int,
default=2,
help="Tensor parallel size")
parser.add_argument("--node-size",
type=int,
default=1,
help="Total number of nodes")
parser.add_argument("--node-rank",
type=int,
default=0,
help="Rank of the current node")
parser.add_argument("--master-addr",
type=str,
default="",
help="Master node IP address")
parser.add_argument("--master-port",
type=int,
default=0,
help="Master node port")
args = parser.parse_args()
args = parse_args()
dp_size = args.dp_size
tp_size = args.tp_size