Implement prompt logprobs & Batched topk for computing logprobs (#1328)

Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
This commit is contained in:
Zhuohan Li
2023-10-16 10:56:50 -07:00
committed by GitHub
parent 928de46888
commit 9d9072a069
14 changed files with 369 additions and 130 deletions

View File

@@ -143,7 +143,7 @@ class ModelConfig:
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU worker."""
# For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the
# NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]