[Chore][Spec Decode] Update check NoneType instead of assigning variables (#18836)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
@@ -146,16 +146,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# req_id -> (input_id -> encoder_output)
|
# req_id -> (input_id -> encoder_output)
|
||||||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
||||||
|
|
||||||
# Set up speculative decoding.
|
|
||||||
self.use_spec_decode = False
|
|
||||||
self.use_aux_hidden_state_outputs = False
|
self.use_aux_hidden_state_outputs = False
|
||||||
if self.speculative_config:
|
# Set up speculative decoding.
|
||||||
self.use_spec_decode = True
|
|
||||||
|
|
||||||
# NOTE(Jiayi): currently we put the entire draft model on
|
# NOTE(Jiayi): currently we put the entire draft model on
|
||||||
# the last PP rank. This is not ideal if there are many
|
# the last PP rank. This is not ideal if there are many
|
||||||
# layers in the draft model.
|
# layers in the draft model.
|
||||||
if get_pp_group().is_last_rank:
|
if self.speculative_config and get_pp_group().is_last_rank:
|
||||||
if self.speculative_config.method == "ngram":
|
if self.speculative_config.method == "ngram":
|
||||||
self.drafter = NgramProposer(self.vllm_config)
|
self.drafter = NgramProposer(self.vllm_config)
|
||||||
elif self.speculative_config.use_eagle():
|
elif self.speculative_config.use_eagle():
|
||||||
@@ -1318,7 +1314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for i in discard_sampled_tokens_req_indices:
|
for i in discard_sampled_tokens_req_indices:
|
||||||
valid_sampled_token_ids[i].clear()
|
valid_sampled_token_ids[i].clear()
|
||||||
|
|
||||||
if not self.use_spec_decode:
|
if not self.speculative_config:
|
||||||
# Speculative decoding is not enabled.
|
# Speculative decoding is not enabled.
|
||||||
spec_token_ids = None
|
spec_token_ids = None
|
||||||
elif self.speculative_config.method == "ngram":
|
elif self.speculative_config.method == "ngram":
|
||||||
@@ -1740,7 +1736,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
hidden_states = outputs
|
hidden_states = outputs
|
||||||
|
|
||||||
if self.use_spec_decode and self.speculative_config.use_eagle():
|
if self.speculative_config and self.speculative_config.use_eagle():
|
||||||
assert isinstance(self.drafter, EagleProposer)
|
assert isinstance(self.drafter, EagleProposer)
|
||||||
self.drafter.dummy_run(num_tokens)
|
self.drafter.dummy_run(num_tokens)
|
||||||
|
|
||||||
@@ -1795,7 +1791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
"initializing the engine.") from e
|
"initializing the engine.") from e
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
if self.use_spec_decode:
|
if self.speculative_config:
|
||||||
draft_token_ids = [[0] for _ in range(num_reqs)]
|
draft_token_ids = [[0] for _ in range(num_reqs)]
|
||||||
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||||
draft_token_ids, self.device)
|
draft_token_ids, self.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user