[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

This commit is contained in:
Cyrus Leung
2024-08-13 13:33:41 +08:00
committed by GitHub
parent 5469146bcc
commit 7025b11d94
59 changed files with 411 additions and 202 deletions

View File

@@ -65,22 +65,28 @@ class Medusa(nn.Module):
def compute_logits(
self, hidden_states: List[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
logits = []
logits_lst: List[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata)
if _logits is None:
# _logits should only be None on rank > 0, in which case
# it should remain true for every lm_head
assert len(logits_lst) == 0
continue
if self.token_map is None:
logits.append(_logits)
logits_lst.append(_logits)
else:
logits.append(-torch.inf * torch.ones(
logits_lst.append(-torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype))
logits[-1][..., self.token_map] = _logits
logits_lst[-1][..., self.token_map] = _logits
return logits
return logits_lst
def sample(
self,