Align vLLM's beam search implementation with HF generate (#857)

This commit is contained in:
Zhuohan Li
2023-09-04 17:29:42 -07:00
committed by GitHub
parent e15932bb60
commit 002800f081
24 changed files with 596 additions and 260 deletions

View File

@@ -1,7 +1,7 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
@@ -16,7 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -230,7 +230,7 @@ class MPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,