[Spec Decode] Unified Parallel Drafting (#32887)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Committed by: GitHub
Parent: 5b2a9422f0
Commit: af3162d3aa
@@ -75,6 +75,7 @@ def parse_args():
|
||||
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
|
||||
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
|
||||
parser.add_argument("--max-num-seqs", type=int, default=None)
|
||||
parser.add_argument("--parallel-drafting", action="store_true")
|
||||
parser.add_argument("--allowed-local-media-path", type=str, default="")
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -121,6 +122,7 @@ def main(args):
|
||||
"model": eagle_dir,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
|
||||
"parallel_drafting": args.parallel_drafting,
|
||||
}
|
||||
elif args.method == "ngram":
|
||||
speculative_config = {
|
||||
@@ -137,6 +139,7 @@ def main(args):
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"enforce_eager": args.enforce_eager,
|
||||
"max_model_len": args.max_model_len,
|
||||
"parallel_drafting": args.parallel_drafting,
|
||||
}
|
||||
elif args.method == "mtp":
|
||||
speculative_config = {
|
||||
|
||||
Reference in New Issue
Block a user