[Doc] Explicitly state that PP isn't compatible with speculative decoding yet (#10975)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-08 01:20:49 +08:00
committed by GitHub
parent 39e227c7ae
commit c889d5888b
8 changed files with 32 additions and 9 deletions

View File

@@ -473,10 +473,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

View File

@@ -400,16 +400,17 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
if hasattr(config, "logits_scaling"):
logit_scale /= config.logits_scaling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
scale=logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

View File

@@ -540,10 +540,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

View File

@@ -435,9 +435,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

View File

@@ -443,10 +443,11 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)