[spec decode] Fix MTP inference path for MiMo-7B model (#25136)
Signed-off-by: zixi-qi <qizixi@meta.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -53,7 +53,6 @@ def parse_args():
|
||||
"--method",
|
||||
type=str,
|
||||
default="eagle",
|
||||
choices=["ngram", "eagle", "eagle3", "mtp"],
|
||||
)
|
||||
parser.add_argument("--num-spec-tokens", type=int, default=2)
|
||||
parser.add_argument("--prompt-lookup-max", type=int, default=5)
|
||||
@@ -118,6 +117,11 @@ def main():
|
||||
"prompt_lookup_max": args.prompt_lookup_max,
|
||||
"prompt_lookup_min": args.prompt_lookup_min,
|
||||
}
|
||||
elif args.method.endswith("mtp"):
|
||||
speculative_config = {
|
||||
"method": args.method,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"unknown method: {args.method}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user