[CORE] Quantized lm-head Framework (#4442)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com> Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
committed by
GitHub
parent
7c008c51a9
commit
ee93f4f92a
@@ -6,6 +6,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_gather
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
|
||||
@@ -40,7 +42,7 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
embedding: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
@@ -52,8 +54,7 @@ class LogitsProcessor(nn.Module):
|
||||
sampling_metadata)
|
||||
|
||||
# Get the logits for the next tokens.
|
||||
logits = self._get_logits(hidden_states, embedding, embedding_bias)
|
||||
|
||||
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
|
||||
if logits is not None:
|
||||
if self.soft_cap is not None:
|
||||
logits = logits / self.soft_cap
|
||||
@@ -68,12 +69,13 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
return logits
|
||||
|
||||
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
|
||||
def _get_logits(self, hidden_states: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = lm_head.linear_method.apply(lm_head,
|
||||
hidden_states,
|
||||
bias=embedding_bias)
|
||||
logits = tensor_model_parallel_gather(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
if logits is not None:
|
||||
|
||||
Reference in New Issue
Block a user