[CORE] Quantized lm-head Framework (#4442)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com> Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
committed by
GitHub
parent
7c008c51a9
commit
ee93f4f92a
@@ -316,11 +316,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.model = Qwen2Model(config, cache_config, quant_config)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head_weight = self.model.embed_tokens.weight
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.lm_head_weight = self.lm_head.weight
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@@ -339,7 +339,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user