[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

This commit is contained in:
Cyrus Leung
2024-08-13 13:33:41 +08:00
committed by GitHub
parent 5469146bcc
commit 7025b11d94
59 changed files with 411 additions and 202 deletions

View File

@@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors)
return hidden_states
-def compute_logits(self, hidden_states: torch.Tensor,
-sampling_metadata: SamplingMetadata) -> torch.Tensor:
+def compute_logits(
+self,
+hidden_states: torch.Tensor,
+sampling_metadata: SamplingMetadata,
+) -> Optional[torch.Tensor]:
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head = self.model.embed_tokens