[CORE] Quantized lm-head Framework (#4442)

Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
Qubitium-ModelCloud
2024-07-03 06:25:17 +08:00
committed by GitHub
parent 7c008c51a9
commit ee93f4f92a
48 changed files with 268 additions and 121 deletions

View File

@@ -363,12 +363,12 @@ class CohereForCausalLM(nn.Module):
sampling_metadata: SamplingMetadata) -> torch.Tensor:
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora:
embedding_weights = self.model.embed_tokens.weight
logits = self.logits_processor(self.model.embed_tokens,
hidden_states, sampling_metadata)
else:
embedding_weights = self.model.embed_tokens.base_layer.weight
logits = self.logits_processor(self.model.embed_tokens.base_layer,
hidden_states, sampling_metadata)
logits = self.logits_processor(embedding_weights, hidden_states,
sampling_metadata)
return logits
def sample(