[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

This commit is contained in:
Cyrus Leung
2024-08-13 13:33:41 +08:00
committed by GitHub
parent 5469146bcc
commit 7025b11d94
59 changed files with 411 additions and 202 deletions

View File

@@ -279,8 +279,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
                                          attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits