[V1][Usage] Refactor speculative decoding configuration and tests (#14434)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -151,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.use_spec_decode = False
|
||||
if self.speculative_config:
|
||||
self.use_spec_decode = True
|
||||
# TODO: find a better way to check if we are using ngram.
|
||||
assert self.speculative_config.ngram_prompt_lookup_min, \
|
||||
assert self.speculative_config.method == "ngram", \
|
||||
"Currently, only ngram spec decode is supported in V1."
|
||||
if get_pp_group().is_last_rank:
|
||||
self.drafter = NgramProposer()
|
||||
@@ -160,7 +159,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# This usually takes less than 1 second.
|
||||
self.drafter.propose(
|
||||
np.zeros(1024, dtype=np.int32),
|
||||
self.speculative_config.ngram_prompt_lookup_min,
|
||||
self.speculative_config.prompt_lookup_min,
|
||||
self.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
self.rejection_sampler = RejectionSampler()
|
||||
@@ -1155,7 +1154,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
|
||||
drafter_output = self.drafter.propose(
|
||||
self.input_batch.token_ids_cpu[i, :end_idx],
|
||||
self.speculative_config.ngram_prompt_lookup_min,
|
||||
self.speculative_config.prompt_lookup_min,
|
||||
self.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
if drafter_output is None or len(drafter_output) == 0:
|
||||
|
||||
Reference in New Issue
Block a user