[CORE] Quantized lm-head Framework (#4442)

Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
Author: Qubitium-ModelCloud
Date: 2024-07-03 06:25:17 +08:00 (committed by GitHub)
Parent: 7c008c51a9
Commit: ee93f4f92a
48 changed files with 268 additions and 121 deletions

vllm/model_executor/models/starcoder2.py

@@ -242,7 +242,7 @@ class Starcoder2ForCausalLM(nn.Module):
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
         if config.tie_word_embeddings:
-            self.lm_head_weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
         else:
             self.unpadded_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
@@ -250,8 +250,8 @@ class Starcoder2ForCausalLM(nn.Module):
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                quant_config=quant_config,
             )
-            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
@@ -270,7 +270,7 @@ class Starcoder2ForCausalLM(nn.Module):
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
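
The Starcoder2 hunks above show the pattern repeated across the 48 changed files: compute_logits now hands the lm_head module itself to the logits processor, rather than its raw weight tensor, so a quantized linear method attached to that module can perform the final projection. Below is a minimal sketch of that idea, assuming an lm_head module that may carry a quant_method attribute; this is a simplified illustration, not vLLM's actual LogitsProcessor.

import torch
import torch.nn as nn


class SketchLogitsProcessor(nn.Module):
    """Sketch: accepts the lm_head *module*, not a bare weight tensor,
    so a quantized kernel attached to the module can compute logits."""

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size  # unpadded vocab size

    def forward(self, lm_head: nn.Module,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # Hypothetical attribute for illustration: a quantized layer
        # would expose its own projection method here.
        quant_method = getattr(lm_head, "quant_method", None)
        if quant_method is not None:
            # Quantized path: the layer's method runs the projection
            # (e.g. a dequantize-and-matmul or fused low-bit kernel).
            logits = quant_method.apply(lm_head, hidden_states)
        else:
            # Unquantized fallback, including tied embeddings: a plain
            # matmul against the module's [vocab, hidden] weight.
            logits = hidden_states @ lm_head.weight.t()
        # Drop any rows added by vocab padding before returning.
        return logits[..., :self.vocab_size]

Passing the module also lets the tie_word_embeddings branch reuse self.model.embed_tokens directly, which is why the old lm_head_weight indirection is removed in the diff.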