[Spec Decode] Unified Parallel Drafting (#32887)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-02-05 12:37:18 -05:00
committed by GitHub
parent 5b2a9422f0
commit af3162d3aa
14 changed files with 1085 additions and 392 deletions

View File

@@ -75,6 +75,7 @@ def parse_args():
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None)
parser.add_argument("--parallel-drafting", action="store_true")
parser.add_argument("--allowed-local-media-path", type=str, default="")
return parser.parse_args()
@@ -121,6 +122,7 @@ def main(args):
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
"parallel_drafting": args.parallel_drafting,
}
elif args.method == "ngram":
speculative_config = {
@@ -137,6 +139,7 @@ def main(args):
"num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
"parallel_drafting": args.parallel_drafting,
}
elif args.method == "mtp":
speculative_config = {