[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
Travis Johnson
2024-08-08 23:08:46 -06:00
committed by GitHub
parent e02ac55617
commit 99b4cf5f23
4 changed files with 66 additions and 5 deletions

View File

@@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module):
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
logits = logits[..., :self.org_vocab_size]
return logits
def extra_repr(self) -> str: