feat: spec decode with draft models (#24322)
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
This commit is contained in:
@@ -54,7 +54,7 @@ def parse_args():
|
||||
"--method",
|
||||
type=str,
|
||||
default="eagle",
|
||||
choices=["ngram", "eagle", "eagle3", "mtp"],
|
||||
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
|
||||
)
|
||||
parser.add_argument("--num-spec-tokens", type=int, default=2)
|
||||
parser.add_argument("--prompt-lookup-max", type=int, default=5)
|
||||
@@ -70,7 +70,11 @@ def parse_args():
|
||||
parser.add_argument("--output-len", type=int, default=256)
|
||||
parser.add_argument("--model-dir", type=str, default=None)
|
||||
parser.add_argument("--eagle-dir", type=str, default=None)
|
||||
parser.add_argument("--draft-model", type=str, default=None)
|
||||
parser.add_argument("--custom-mm-prompts", action="store_true")
|
||||
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
|
||||
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
|
||||
parser.add_argument("--max-num-seqs", type=int, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -111,6 +115,7 @@ def main(args):
|
||||
"method": args.method,
|
||||
"model": eagle_dir,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
|
||||
}
|
||||
elif args.method == "ngram":
|
||||
speculative_config = {
|
||||
@@ -119,6 +124,15 @@ def main(args):
|
||||
"prompt_lookup_max": args.prompt_lookup_max,
|
||||
"prompt_lookup_min": args.prompt_lookup_min,
|
||||
}
|
||||
elif args.method == "draft_model":
|
||||
assert args.draft_model is not None and args.draft_model != ""
|
||||
speculative_config = {
|
||||
"method": args.method,
|
||||
"model": args.draft_model,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"enforce_eager": args.enforce_eager,
|
||||
"max_model_len": args.max_model_len,
|
||||
}
|
||||
elif args.method == "mtp":
|
||||
speculative_config = {
|
||||
"method": "mtp",
|
||||
@@ -133,12 +147,13 @@ def main(args):
|
||||
tensor_parallel_size=args.tp,
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
enforce_eager=args.enforce_eager,
|
||||
gpu_memory_utilization=0.9,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
max_model_len=args.max_model_len,
|
||||
limit_mm_per_prompt={"image": 5},
|
||||
disable_chunked_mm_input=True,
|
||||
max_num_seqs=args.max_num_seqs,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||
|
||||
Reference in New Issue
Block a user