[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)
This commit is contained in:
@@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
hidden_states = hidden_states / self.scale_width
|
||||
if self.config.tie_word_embeddings:
|
||||
lm_head = self.model.embed_tokens
|
||||
|
||||
Reference in New Issue
Block a user