[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
@@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module):
|
||||
states.add_(z, alpha=self.emb_weight / self.state_weight)
|
||||
|
||||
states = self.activation(self.ln[head_index](states)) # b k d
|
||||
- # TODO: not yet supporting top_k_tokens_per_head
|
||||
previous_hidden_states = states
|
||||
+ # TODO: not yet supporting top_k_tokens_per_head
|
||||
+ states = states.flatten(0, 1)
|
||||
|
||||
logits = self.logits_processor(self.head[head_index], states,
|
||||
sampling_metadata)
|
||||
|
||||
- output = self.sampler(logits.flatten(0, 1), sampling_metadata)
|
||||
+ output = self.sampler(logits, sampling_metadata)
|
||||
last_tokens = output.sampled_token_ids
|
||||
next_tokens.append(output)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user