[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)
@@ -50,7 -50,7 @@ class LogitsProcessor(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
 
         return logits
 
-    def _get_logits(self, hidden_states: torch.Tensor,
-                    lm_head: VocabParallelEmbedding,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
         if self.use_gather:
+            # None may be returned for rank > 0
             logits = tensor_model_parallel_gather(logits)
         else:
             # Gather is not supported for some devices such as TPUs.
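Why the return type becomes Optional[torch.Tensor]: with tensor parallelism, tensor_model_parallel_gather collects the vocab-sharded logits onto a single destination rank, and every other rank gets None back, so callers of _get_logits must tolerate a None result. Below is a minimal sketch of that behavior, assuming an initialized torch.distributed process group; gather_to_rank0 is a hypothetical stand-in for illustration, not vLLM's actual tensor_model_parallel_gather:

```python
from typing import Optional

import torch
import torch.distributed as dist


def gather_to_rank0(logits: torch.Tensor,
                    dst: int = 0) -> Optional[torch.Tensor]:
    """Gather vocab-sharded logits onto `dst`; all other ranks get None."""
    world_size = dist.get_world_size()
    if world_size == 1:
        return logits  # single rank, nothing to gather
    if dist.get_rank() == dst:
        # The destination rank provides one receive buffer per rank and
        # concatenates the gathered shards along the vocab dimension.
        shards = [torch.empty_like(logits) for _ in range(world_size)]
        dist.gather(logits, gather_list=shards, dst=dst)
        return torch.cat(shards, dim=-1)
    # Non-destination ranks join the collective but receive nothing back,
    # which is why the method's return type is Optional[torch.Tensor].
    dist.gather(logits, gather_list=None, dst=dst)
    return None
```

On devices where a rank-0 gather is unavailable (the truncated else branch in the diff mentions TPUs), an all-gather that leaves the full logits on every rank is the usual alternative, at the cost of replicating the tensor across ranks.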