[Chore][Spec Decode] Update check NoneType instead of assigning variables (#18836)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham
2025-05-28 14:57:19 -04:00
committed by GitHub
parent 0e98964e94
commit a09c7ca9f2

View File

@@ -146,31 +146,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
 
-        # Set up speculative decoding.
-        self.use_spec_decode = False
         self.use_aux_hidden_state_outputs = False
-        if self.speculative_config:
-            self.use_spec_decode = True
-
-            # NOTE(Jiayi): currently we put the entire draft model on
-            # the last PP rank. This is not ideal if there are many
-            # layers in the draft model.
-            if get_pp_group().is_last_rank:
-                if self.speculative_config.method == "ngram":
-                    self.drafter = NgramProposer(self.vllm_config)
-                elif self.speculative_config.use_eagle():
-                    self.drafter = EagleProposer(self.vllm_config, self.device,
-                                                 self)  # type: ignore
-                    if self.speculative_config.method == "eagle3":
-                        self.use_aux_hidden_state_outputs = True
-                elif self.speculative_config.method == "medusa":
-                    self.drafter = MedusaProposer(
-                        vllm_config=self.vllm_config,
-                        device=self.device)  # type: ignore
-                else:
-                    raise ValueError("Unknown speculative decoding method: "
-                                     f"{self.speculative_config.method}")
-                self.rejection_sampler = RejectionSampler()
+        # Set up speculative decoding.
+        # NOTE(Jiayi): currently we put the entire draft model on
+        # the last PP rank. This is not ideal if there are many
+        # layers in the draft model.
+        if self.speculative_config and get_pp_group().is_last_rank:
+            if self.speculative_config.method == "ngram":
+                self.drafter = NgramProposer(self.vllm_config)
+            elif self.speculative_config.use_eagle():
+                self.drafter = EagleProposer(self.vllm_config, self.device,
+                                             self)  # type: ignore
+                if self.speculative_config.method == "eagle3":
+                    self.use_aux_hidden_state_outputs = True
+            elif self.speculative_config.method == "medusa":
+                self.drafter = MedusaProposer(
+                    vllm_config=self.vllm_config,
+                    device=self.device)  # type: ignore
+            else:
+                raise ValueError("Unknown speculative decoding method: "
+                                 f"{self.speculative_config.method}")
+            self.rejection_sampler = RejectionSampler()
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
@@ -1318,7 +1314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()
 
-        if not self.use_spec_decode:
+        if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
         elif self.speculative_config.method == "ngram":
@@ -1740,7 +1736,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             hidden_states = outputs
 
-        if self.use_spec_decode and self.speculative_config.use_eagle():
+        if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             self.drafter.dummy_run(num_tokens)
@@ -1795,7 +1791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     "initializing the engine.") from e
             else:
                 raise e
-        if self.use_spec_decode:
+        if self.speculative_config:
             draft_token_ids = [[0] for _ in range(num_reqs)]
             dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                 draft_token_ids, self.device)