[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)
@@ -65,22 +65,28 @@ class Medusa(nn.Module):
     def compute_logits(
             self, hidden_states: List[torch.Tensor],
             sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
-        logits = []
+        logits_lst: List[torch.Tensor] = []
 
         for hs, lm_head in zip(hidden_states, self.lm_heads):
             _logits = self.logits_processor(lm_head, hs, sampling_metadata)
 
+            if _logits is None:
+                # _logits should only be None on rank > 0, in which case
+                # it should remain true for every lm_head
+                assert len(logits_lst) == 0
+                continue
+
             if self.token_map is None:
-                logits.append(_logits)
+                logits_lst.append(_logits)
             else:
-                logits.append(-torch.inf * torch.ones(
+                logits_lst.append(-torch.inf * torch.ones(
                     size=(*_logits.shape[:-1], self.orig_vocab_size),
                     device=_logits.device,
                     dtype=_logits.dtype))
 
-                logits[-1][..., self.token_map] = _logits
+                logits_lst[-1][..., self.token_map] = _logits
 
-        return logits
+        return logits_lst
 
     def sample(
         self,
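For context on the new early-return: the diff's own comment says _logits is only None on rank > 0, which is consistent with the full-vocabulary logits being gathered onto the driver rank under tensor parallelism, so every other rank must skip its per-head bookkeeping. Below is a minimal, self-contained sketch of that pattern under this assumption; it is not vLLM's actual code, and fake_gather_logits and compute_logits_sketch are made-up names for illustration only.

# Sketch (assumption): logits materialize only on rank 0, every other rank
# receives None for every head and must not append anything.
from typing import List, Optional

import torch


def fake_gather_logits(local_logits: torch.Tensor,
                       rank: int) -> Optional[torch.Tensor]:
    """Stand-in for a TP gather that returns logits only on rank 0."""
    return local_logits if rank == 0 else None


def compute_logits_sketch(hidden_states: List[torch.Tensor],
                          rank: int) -> List[torch.Tensor]:
    logits_lst: List[torch.Tensor] = []
    for hs in hidden_states:
        _logits = fake_gather_logits(hs, rank)
        if _logits is None:
            # Mirrors the assert in the diff: either every head yields
            # logits (driver rank) or none of them do (other ranks).
            assert len(logits_lst) == 0
            continue
        logits_lst.append(_logits)
    return logits_lst


# On a non-driver rank the list stays empty; on the driver rank there is
# one tensor per head.
heads_out = [torch.randn(2, 8) for _ in range(3)]
assert compute_logits_sketch(heads_out, rank=1) == []
assert len(compute_logits_sketch(heads_out, rank=0)) == 3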
||||