[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
@@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module):
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
if logits is not None:
|
||||
logits = logits[:, :self.org_vocab_size]
|
||||
logits = logits[..., :self.org_vocab_size]
|
||||
return logits
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user