[CORE] Quantized lm-head Framework (#4442)

Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
Authored by Qubitium-ModelCloud on 2024-07-03 06:25:17 +08:00; committed by GitHub
parent 7c008c51a9
commit ee93f4f92a
48 changed files with 268 additions and 121 deletions

@@ -6,6 +6,8 @@ import torch
 import torch.nn as nn
 from vllm.distributed import tensor_model_parallel_gather
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -40,7 +42,7 @@ class LogitsProcessor(nn.Module):
     def forward(
         self,
-        embedding: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
@@ -52,8 +54,7 @@ class LogitsProcessor(nn.Module):
                                                  sampling_metadata)
             # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, embedding, embedding_bias)
+            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
         if logits is not None:
             if self.soft_cap is not None:
                 logits = logits / self.soft_cap
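
With the new signature, callers hand the whole lm_head layer to the logits processor instead of its weight tensor. A minimal caller-side sketch, assuming a typical vLLM model's compute_logits method (illustrative, not a verbatim excerpt from this commit):

    # Hypothetical model method, shown only to illustrate the call-site change.
    def compute_logits(self, hidden_states, sampling_metadata):
        # Before this change, the raw weight tensor was passed:
        #   self.logits_processor(self.lm_head.weight, hidden_states,
        #                         sampling_metadata)
        # After this change, the lm_head layer itself is passed, so its
        # (possibly quantized) linear method can produce the logits.
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)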
@@ -68,12 +69,13 @@ class LogitsProcessor(nn.Module):
         return logits
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+    def _get_logits(self, hidden_states: torch.Tensor,
+                    lm_head: VocabParallelEmbedding,
                     embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
+        logits = lm_head.linear_method.apply(lm_head,
+                                             hidden_states,
+                                             bias=embedding_bias)
         logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
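
The core of the framework is visible in _get_logits: the logits projection is no longer a hard-coded torch.matmul against a weight tensor but is delegated to the lm_head layer's linear_method, so a quantized method can supply its own kernel. Below is a small, self-contained PyTorch sketch of that mechanism; UnquantizedMethod and FakeLMHead are illustrative stand-ins, not vLLM classes.

    from typing import Optional

    import torch


    class UnquantizedMethod:
        """Stand-in for an unquantized linear method: plain matmul + bias."""

        def apply(self, layer: torch.nn.Module, x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
            out = torch.matmul(x, layer.weight.t())
            if bias is not None:
                out = out + bias
            return out


    class FakeLMHead(torch.nn.Module):
        """Stand-in for an lm-head layer: a weight plus a pluggable linear_method."""

        def __init__(self, vocab_size: int, hidden_size: int) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size))
            self.linear_method = UnquantizedMethod()


    hidden_states = torch.randn(4, 16)
    lm_head = FakeLMHead(vocab_size=32, hidden_size=16)

    # Old path (pre-diff): raw matmul against the embedding weight.
    logits_old = torch.matmul(hidden_states, lm_head.weight.t())

    # New path (post-diff): delegate to the layer's linear_method, which a
    # quantized implementation could replace with a dequantizing/fused kernel.
    logits_new = lm_head.linear_method.apply(lm_head, hidden_states, bias=None)

    assert torch.allclose(logits_old, logits_new)

In the real change, the lm_head is a VocabParallelEmbedding (or ParallelLMHead), and its linear_method is presumably selected from the model's quantization config, which is what lets quantized kernels handle the lm-head the same way they already handle other linear layers.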