[CORE] Quantized lm-head Framework (#4442)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
committed by GitHub
parent 7c008c51a9
commit ee93f4f92a
```diff
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -8,7 +8,7 @@ from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import MLPSpeculatorConfig
```
```diff
@@ -87,7 +87,7 @@ class MLPSpeculator(nn.Module):
         self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                   (self.max_speculative_tokens - 1))
 
-        head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+        head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
         self.head = nn.ModuleList([head] * self.max_speculative_tokens)
 
         ln = MLPSpeculatorLayerNorm(self.inner_dim,
```
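This constructor change is the heart of the framework: each speculative head becomes a `ParallelLMHead` instead of a bare `nn.Linear`, so the lm-head can participate in vLLM's quantization machinery (note the flipped argument order, since `ParallelLMHead` takes the vocab size before the hidden dimension). A minimal sketch of why the swap matters, where `HypotheticalQuantMethod` and `SketchLMHead` are illustrative stand-ins rather than vLLM's actual classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HypotheticalQuantMethod:
    # Stand-in for a real vLLM quantization scheme; names are illustrative.
    def create_weights(self, layer: nn.Module, input_dim: int,
                       output_dim: int) -> None:
        # A real scheme would allocate packed integer weights plus scales;
        # a dense parameter keeps this sketch runnable.
        layer.weight = nn.Parameter(torch.empty(output_dim, input_dim))

    def apply(self, layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
        # A real scheme would invoke its dequantize/matmul kernel here.
        return F.linear(x, layer.weight)


class SketchLMHead(nn.Module):
    # Simplified stand-in for ParallelLMHead: unlike nn.Linear, the layer
    # delegates weight creation (and later the projection itself) to an
    # attached quant_method, which is what makes a quantized lm-head possible.
    def __init__(self, num_embeddings: int, embedding_dim: int, quant_method):
        super().__init__()
        self.quant_method = quant_method
        self.quant_method.create_weights(self, embedding_dim, num_embeddings)
```

A plain `nn.Linear` only carries a dense fp16/fp32 weight, so nothing downstream could ever substitute a packed-weight kernel for it.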
```diff
@@ -169,8 +169,8 @@ class MLPSpeculator(nn.Module):
             # TODO: not yet supporting top_k_tokens_per_head
             previous_hidden_states = states
 
-            logits = self.logits_processor(self.head[head_index].weight,
-                                           states, sampling_metadata)
+            logits = self.logits_processor(self.head[head_index], states,
+                                           sampling_metadata)
 
             output = self.sampler(logits.flatten(0, 1), sampling_metadata)
             last_tokens = output.sampled_token_ids
```
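The forward-pass change follows from the constructor: `LogitsProcessor` now receives the lm-head module itself rather than its raw `.weight` tensor, so it can route the projection through the layer's quantization method when one is attached; a bare weight tensor carries no such method. A hedged sketch of that dispatch, where the dense fallback branch is an assumption for unquantized heads rather than a copy of vLLM's internals:

```python
import torch


def get_logits_sketch(lm_head: torch.nn.Module,
                      hidden_states: torch.Tensor) -> torch.Tensor:
    quant_method = getattr(lm_head, "quant_method", None)
    if quant_method is not None:
        # Quantized path: the method owns the packed weights and the kernel.
        return quant_method.apply(lm_head, hidden_states)
    # Unquantized fallback: plain matmul against the dense weight.
    return hidden_states @ lm_head.weight.t()
```

Passing the module instead of the tensor is what turns the logits projection into a pluggable step: any quantization scheme that attaches a `quant_method` to the head is picked up without further changes at this call site.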