Align vLLM's beam search implementation with HF generate (#857)

This commit is contained in:
Zhuohan Li
2023-09-04 17:29:42 -07:00
committed by GitHub
parent e15932bb60
commit 002800f081
24 changed files with 596 additions and 260 deletions

View File

@@ -22,7 +22,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@@ -39,7 +39,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -246,7 +246,7 @@ class GPTBigCodeForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,