[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
Travis Johnson
2024-08-08 23:08:46 -06:00
committed by GitHub
parent e02ac55617
commit 99b4cf5f23
4 changed files with 66 additions and 5 deletions

View File

@@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module):
states.add_(z, alpha=self.emb_weight / self.state_weight)
states = self.activation(self.ln[head_index](states)) # b k d
# TODO: not yet supporting top_k_tokens_per_head
previous_hidden_states = states
# TODO: not yet supporting top_k_tokens_per_head
states = states.flatten(0, 1)
logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)
output = self.sampler(logits.flatten(0, 1), sampling_metadata)
output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids
next_tokens.append(output)