[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

Author: Cyrus Leung
Date: 2024-08-13 13:33:41 +08:00
Committed by: GitHub
Parent: 5469146bcc
Commit: 7025b11d94
59 changed files with 411 additions and 202 deletions

@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
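
Why the return type widens to Optional: under tensor parallelism the gathered logits only materialize on the driver rank, so every other rank now legitimately returns None, and callers must guard for it. A minimal caller-side sketch of the resulting contract (the step function below is hypothetical, not vLLM's sampler API):

    from typing import Optional

    import torch

    def step(logits: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Non-driver TP ranks receive None from forward() and skip sampling.
        if logits is None:
            return None
        # Only the driver rank holds the full-vocab logits and can sample.
        return torch.argmax(logits, dim=-1)
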
@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
         return logits
 
-    def _get_logits(self, hidden_states: torch.Tensor,
-                    lm_head: VocabParallelEmbedding,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
         if self.use_gather:
+            # None may be returned for rank > 0
             logits = tensor_model_parallel_gather(logits)
         else:
             # Gather is not supported for some devices such as TPUs.
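
For reference, a hedged sketch of the two collectives this branch chooses between, written against raw torch.distributed rather than vLLM's tensor_model_parallel_gather / tensor_model_parallel_all_gather wrappers (the wrappers' exact behavior is assumed here): gather delivers the concatenated vocab shards to rank 0 only, leaving other ranks with None, while all_gather, the fallback for devices such as TPUs where gather is unsupported, reconstructs the full logits on every rank.

    from typing import Optional

    import torch
    import torch.distributed as dist

    def gather_logit_shards(shard: torch.Tensor,
                            use_gather: bool) -> Optional[torch.Tensor]:
        # `shard` holds this rank's slice of the vocab dimension.
        rank, world = dist.get_rank(), dist.get_world_size()
        if use_gather:
            # gather: only dst=0 provides receive buffers and gets the shards;
            # every other rank passes None and ends up returning None.
            bufs = ([torch.empty_like(shard) for _ in range(world)]
                    if rank == 0 else None)
            dist.gather(shard, gather_list=bufs, dst=0)
            return torch.cat(bufs, dim=-1) if rank == 0 else None
        # all_gather: every rank allocates buffers and reconstructs the
        # full logits, so no rank returns None on this path.
        bufs = [torch.empty_like(shard) for _ in range(world)]
        dist.all_gather(bufs, shard)
        return torch.cat(bufs, dim=-1)
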