[Bugfix] Fix TP > 1 for new granite (#8544)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde
2024-09-17 17:17:08 -06:00
committed by GitHub
parent 56c3de018c
commit 98f9713399

View File

@@ -428,7 +428,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits /= self.config.logits_scaling
if logits is not None:
logits /= self.config.logits_scaling
return logits
def sample(