[CORE] Quantized lm-head Framework (#4442)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com> Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
committed by
GitHub
parent
7c008c51a9
commit
ee93f4f92a
@@ -363,12 +363,12 @@ class CohereForCausalLM(nn.Module):
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
|
||||
if is_not_lora:
|
||||
embedding_weights = self.model.embed_tokens.weight
|
||||
logits = self.logits_processor(self.model.embed_tokens,
|
||||
hidden_states, sampling_metadata)
|
||||
else:
|
||||
embedding_weights = self.model.embed_tokens.base_layer.weight
|
||||
logits = self.logits_processor(self.model.embed_tokens.base_layer,
|
||||
hidden_states, sampling_metadata)
|
||||
|
||||
logits = self.logits_processor(embedding_weights, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
|
||||
Reference in New Issue
Block a user