[V1][Usage] Refactor speculative decoding configuration and tests (#14434)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc
2025-03-23 13:28:10 +08:00
committed by GitHub
parent 0661cfef7a
commit 50c9636d87
20 changed files with 1055 additions and 802 deletions

View File

@@ -151,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
-# TODO: find a better way to check if we are using ngram.
-assert self.speculative_config.ngram_prompt_lookup_min, \
+assert self.speculative_config.method == "ngram", \
"Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank:
self.drafter = NgramProposer()
@@ -160,7 +159,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
-self.speculative_config.ngram_prompt_lookup_min,
+self.speculative_config.prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
self.rejection_sampler = RejectionSampler()
@@ -1155,7 +1154,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx],
-self.speculative_config.ngram_prompt_lookup_min,
+self.speculative_config.prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
if drafter_output is None or len(drafter_output) == 0: