[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)
This commit is contained in:
@@ -279,8 +279,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
|
||||
attn_metadata)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user