Migrate logits computation and gather to model_runner (#3233)

This commit is contained in:
Roy
2024-03-21 07:25:01 +08:00
committed by GitHub
parent 6e435de766
commit f1c0fc3919
35 changed files with 576 additions and 305 deletions

View File

@@ -0,0 +1,106 @@
"""A layer that compute logits from hidden_stats."""
from typing import Optional
import torch
import torch.nn as nn
from vllm.utils import is_neuron
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
class LogitsProcessor(nn.Module):
    """Process logits and apply logits processors from sampling metadata.

    This layer does the following:
    1. Gather logits from model hidden_states.
    2. Scale logits if needed.
    3. Apply logits processors (if any).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: Optional[float] = 1.0) -> None:
        """
        Args:
            vocab_size: Vocabulary size; stored on the layer unchanged.
            org_vocab_size: Original vocabulary size (without LoRA). Logit
                columns at or beyond this index are dropped in
                ``_get_logits``. Falls back to ``vocab_size`` when omitted
                (or otherwise falsy).
            scale: A scaling factor to apply to the logits. ``None`` and
                ``1.0`` both mean "do not scale".
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Transformers-neuronx generates its outputs as logits directly, so
        # on Neuron the incoming hidden_states are used as the logits.
        self.logits_as_hidden_states = is_neuron()
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into scaled, post-processed logits.

        Args:
            embedding: Weight matrix; its transpose is multiplied with the
                selected hidden states to produce the logits.
            hidden_states: Model output. Used directly as the logits when
                ``self.logits_as_hidden_states`` is set.
            sampling_metadata: Supplies the token indices to keep and the
                per-sequence-group logits processors.
            embedding_bias: Optional bias added to the logits before they
                are gathered.

        Returns:
            The processed logits, or ``None`` when ``_get_logits`` yields
            ``None``.
        """
        if self.logits_as_hidden_states:
            logits = hidden_states
        else:
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)
            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, embedding,
                                      embedding_bias)

        if logits is None:
            return None

        # ``scale`` is Optional: multiplying by None would raise TypeError,
        # and multiplying by 1.0 is a no-op, so skip both. The scaling stays
        # in place, exactly as before.
        if self.scale is not None and self.scale != 1.0:
            logits *= self.scale

        # Apply logits processors (if any).
        return _apply_logits_processors(logits, sampling_metadata)

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Project hidden states onto the vocabulary and gather the result.

        Returns:
            The gathered logits truncated to ``self.org_vocab_size``
            columns, or ``None`` when the gather returns ``None``
            (presumably on ranks that are not the gather destination --
            TODO confirm against ``tensor_model_parallel_gather``).
        """
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits