Compare commits

61 Commits

Author SHA1 Message Date
Woosuk Kwon
c5f7740d89 Bump up to v0.2.2 (#1689)
2023-11-18 21:57:07 -08:00
Woosuk Kwon
be66d9b125 Fix warning msg on quantization (#1715) 2023-11-18 21:49:55 -08:00
ljss
e1054247ba [Optimization] Implement fused add rmsnorm (#1667) 2023-11-18 18:18:02 -08:00
Woosuk Kwon
8d17774f92 Add AWQ support for all models (#1714) 2023-11-18 17:56:47 -08:00
twaka
e946260cf3 use get_tensor in safe_open (#1696) 2023-11-18 16:45:18 -08:00
liuyhwangyh
edb305584b Support download models from www.modelscope.cn (#1588) 2023-11-17 20:38:31 -08:00
Woosuk Kwon
bb00f66e19 Use quantization_config in hf config (#1695) 2023-11-17 16:23:49 -08:00
Roy
e87557b069 Support Min P Sampler (#1642) 2023-11-17 16:20:49 -08:00
Zhuofan
dcc543a298 [Minor] Fix comment (#1704) 2023-11-17 09:42:49 -08:00
Zhuohan Li
0fc280b06c Update the adding-model doc according to the new refactor (#1692) 2023-11-16 18:46:26 -08:00
Zhuohan Li
20d0699d49 [Fix] Fix comm test (#1691) 2023-11-16 16:28:39 -08:00
Iskren Ivov Chernev
686f5e3210 Return usage for openai streaming requests (#1663) 2023-11-16 15:28:36 -08:00
Zhuohan Li
415d109527 [Fix] Update Supported Models List (#1690) 2023-11-16 14:47:26 -08:00
maximzubkov
521b35f799 Support Microsoft Phi 1.5 (#1664) 2023-11-16 14:28:39 -08:00
Simon Mo
cb08cd0d75 [Minor] Fix duplication of ignored seq group in engine step (#1666) 2023-11-16 13:11:41 -08:00
twaka
2a2c135b41 Fix loading error when safetensors contains empty tensor (#1687) 2023-11-16 10:38:10 -08:00
Aaron Pham
65ea2ddf17 feat(config): support parsing torch.dtype (#1641)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-16 01:31:06 -08:00
Megha Agarwal
b514d3c496 Revert MptConfig to MPTConfig (#1668) 2023-11-16 01:19:39 -08:00
Zhuohan Li
7076fa1c9f TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)
Refactor the tensor parallelism, quantization, and weight-loading codes.

Summary of the new features enabled by this PR:
- **All models** are able to be quantized with AWQ and SqueezeLLM, and [soon GPTQ](https://github.com/vllm-project/vllm/pull/1580).
- Model loading code became much simpler.
- Support model parallelism for all MQA/GQA models when the number of key/value heads is smaller than the tensor parallel size.
2023-11-15 22:50:41 -08:00
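The third bullet in the summary above (model parallelism when there are fewer key/value heads than tensor-parallel ranks) can be read as a head-partitioning rule. The following is a minimal sketch, assuming that KV heads are replicated across ranks in that case; `kv_heads_per_rank` is a hypothetical helper written for illustration and is not a name taken from the PR.

```python
from typing import Tuple


def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> Tuple[int, int]:
    """Return (kv_heads_on_each_rank, ranks_sharing_each_kv_head).

    Hypothetical helper: assumes heads are split evenly when there are
    enough of them, and replicated otherwise.
    """
    if total_num_kv_heads >= tp_size:
        # Enough KV heads: partition them evenly, no replication.
        if total_num_kv_heads % tp_size != 0:
            raise ValueError("KV heads must be divisible by the TP size")
        return total_num_kv_heads // tp_size, 1
    # Fewer KV heads than ranks: each rank holds one head, and every
    # head is shared by several ranks.
    if tp_size % total_num_kv_heads != 0:
        raise ValueError("TP size must be divisible by the number of KV heads")
    return 1, tp_size // total_num_kv_heads
```

For example, an MQA model with a single KV head on 8 ranks would give each rank one copy of that head under this assumption.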
Woosuk Kwon
660a7fcfa4 Add DeepSpeed MII backend to benchmark script (#1649) 2023-11-14 12:35:30 -08:00
Woosuk Kwon
054072bee5 [Minor] Move RoPE selection logic to get_rope (#1633) 2023-11-12 16:04:50 -08:00
lirui
eb825c1e74 Fix #1474 - AssertionError:assert param_slice.shape == loaded_weight.shape (#1631) 2023-11-12 15:53:12 -08:00
Dominik Schwabe
1b290ace4f Run default _AsyncLLMEngine._run_workers_async in threadpool (#1628) 2023-11-11 14:50:44 -08:00
Sin
0d578228ca config parser: add ChatGLM2 seq_length to _get_and_verify_max_len (#1617) 2023-11-09 19:29:51 -08:00
GhaziSyed
aebfcb262a Dockerfile: Upgrade Cuda to 12.1 (#1609) 2023-11-09 11:49:02 -08:00
forpanyang
ab9e8488d5 Add Yi model to quantization support (#1600) 2023-11-09 11:47:14 -08:00
Woosuk Kwon
fd58b73a40 Build CUDA11.8 wheels for release (#1596) 2023-11-09 03:52:29 -08:00
Yanming W
8efe23f150 Fix input_metadata.selected_token_indices in worker prepare_inputs (#1546) 2023-11-08 14:19:12 -08:00
Zhuohan Li
06458a0b42 Upgrade to CUDA 12 (#1527)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2023-11-08 14:17:49 -08:00
GoHomeToMacDonal
1a2bbc9301 ChatGLM Support (#1261) 2023-11-06 16:09:33 -08:00
Roy
e7f579eb97 Support Yi model (#1567) 2023-11-06 15:26:03 -08:00
Casper
8516999495 Add Quantization and AutoAWQ to docs (#1235) 2023-11-04 22:43:39 -07:00
Antoni Baum
9f669a9a7c Support YaRN models (#1264)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Viktor Ferenczi <viktor@ferenczi.eu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2023-11-03 14:12:48 -07:00
Noam Gat
555bdcc5a3 Added logits processor API to sampling params (#1469) 2023-11-03 14:12:15 -07:00
lots-o
54ca1ba71d docs: add description (#1553) 2023-11-03 09:14:52 -07:00
Antoni Baum
9738b84a08 Force paged attention v2 for long contexts (#1510) 2023-11-01 16:24:32 -07:00
Woosuk Kwon
1fe0990023 Remove MPTConfig (#1529) 2023-11-01 15:29:05 -07:00
Fluder-Paradyne
7e90a2d117 Add /health Endpoint for both Servers (#1540) 2023-11-01 10:29:44 -07:00
ljss
5687d584fe [BugFix] Set engine_use_ray=True when TP>1 (#1531) 2023-11-01 02:14:18 -07:00
Wenfei Yan
cf8849f2d6 Add MptForCausalLM key in model_loader (#1526) 2023-10-31 15:46:53 -07:00
Cade Daniel
e575df33b1 [Small] Formatter only checks lints in changed files (#1528) 2023-10-31 15:39:38 -07:00
Woosuk Kwon
0ce8647dc5 Fix integer overflows in attention & cache ops (#1514) 2023-10-31 15:19:30 -07:00
Stephen Krider
9cabcb7645 Add Dockerfile (#1350) 2023-10-31 12:36:47 -07:00
Zhuohan Li
7b895c5976 [Fix] Fix duplicated logging messages (#1524) 2023-10-31 09:04:47 -07:00
Dan Lord
7013a80170 Add support for spaces_between_special_tokens 2023-10-30 16:52:56 -07:00
Jared Roesch
79a30912b8 Add py.typed so consumers of vLLM can get type checking (#1509)
* Add py.typed so consumers of vLLM can get type checking

* Update py.typed

---------
Co-authored-by: aarnphm <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-10-30 14:50:47 -07:00
Adam Brusselback
2f3d36a8a1 Fix logging so we actually get info level entries in the log. (#1494) 2023-10-30 10:02:21 -07:00
iongpt
ac8d36f3e5 Refactor LLMEngine demo script for clarity and modularity (#1413)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-10-30 09:14:37 -07:00
Antoni Baum
15f5632365 Delay GPU->CPU sync in sampling (#1337) 2023-10-30 09:01:34 -07:00
Woosuk Kwon
aa9af07cac Fix bias in InternLM (#1501) 2023-10-29 16:24:18 -07:00
ljss
69be658bba Support repetition_penalty (#1424) 2023-10-29 10:02:41 -07:00
Ricardo Lu
beac8dd461 fix: don't skip first special token. (#1497) 2023-10-29 04:26:36 -07:00
Qing
28b47d1e49 Add rope_scaling to Aquila model (#1457) 2023-10-29 04:25:21 -07:00
chooper1
1f24755bf8 Support SqueezeLLM (#1326)
Co-authored-by: squeeze-ai-lab <squeezeailab.bair@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2023-10-21 23:14:59 -07:00
Thiago Salvatore
bf31d3606a Pin pydantic dependency versions (#1429) 2023-10-21 11:18:58 -07:00
Wang Ran (汪然)
d189170b6c remove useless statements (#1408) 2023-10-20 08:52:07 -07:00
Light Lin
f61dc8072f Fix type hints (#1427) 2023-10-20 08:50:47 -07:00
Woosuk Kwon
f8a1e39fae [BugFix] Define __eq__ in SequenceGroupOutputs (#1389) 2023-10-17 01:09:44 -07:00
Wang Ran (汪然)
a132435204 Fix typo (#1383) 2023-10-16 21:53:37 -07:00
Woosuk Kwon
9524867701 Add Mistral 7B to test_models (#1366) 2023-10-16 17:49:54 -07:00
Woosuk Kwon
c1376e0f82 Change scheduler & input tensor shape (#1381) 2023-10-16 17:48:42 -07:00
95 changed files with 5029 additions and 2578 deletions

@@ -49,8 +49,8 @@ jobs:
 matrix:
 os: ['ubuntu-20.04']
 python-version: ['3.8', '3.9', '3.10', '3.11']
-pytorch-version: ['2.0.1']
-cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
+pytorch-version: ['2.1.0']
+cuda-version: ['11.8', '12.1']
 steps:
 - name: Checkout

@@ -11,5 +11,8 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
 $python_executable -m pip install wheel packaging
 $python_executable -m pip install -r requirements.txt
+# Limit the number of parallel jobs to avoid OOM
+export MAX_JOBS=1
 # Build
 $python_executable setup.py bdist_wheel --dist-dir=dist

@@ -16,3 +16,8 @@ sudo apt clean
 # Test nvcc
 PATH=/usr/local/cuda-$1/bin:${PATH}
 nvcc --version
+# Log gcc, g++, c++ versions
+gcc --version
+g++ --version
+c++ --version

72
Dockerfile Normal file
@@ -0,0 +1,72 @@
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
RUN apt-get update -y \
&& apt-get install -y python3-pip
WORKDIR /workspace
# install build and runtime dependencies
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt
# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-dev.txt
# image to build pytorch extensions
FROM dev AS build
# copy input files
COPY csrc csrc
COPY setup.py setup.py
COPY requirements.txt requirements.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py
# max jobs used by Ninja to build extensions
ENV MAX_JOBS=$max_jobs
RUN python3 setup.py build_ext --inplace
# image to run unit testing suite
FROM dev AS test
# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY tests tests
COPY vllm vllm
ENTRYPOINT ["python3", "-m", "pytest", "tests"]
# use CUDA base as CUDA runtime dependencies are already installed via pip
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
# libnccl required for ray
RUN apt-get update -y \
&& apt-get install -y python3-pip
WORKDIR /workspace
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt
FROM vllm-base AS vllm
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm
EXPOSE 8000
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
# openai api server alternative
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate fschat
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

@@ -49,6 +49,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
 - Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
 - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
@@ -59,7 +60,9 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+- Phi-1.5 (`microsoft/phi-1_5`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
+- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

@@ -70,7 +70,7 @@ if __name__ == '__main__':
 parser.add_argument('--tokenizer', type=str, default=None)
 parser.add_argument('--quantization',
 '-q',
-choices=['awq', None],
+choices=['awq', 'squeezellm', None],
 default=None)
 parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
 parser.add_argument('--input-len', type=int, default=32)

@@ -6,18 +6,21 @@ import time
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from tqdm import tqdm from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
def sample_requests( def sample_requests(
dataset_path: str, dataset_path: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None:
if fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset. # Load the dataset.
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
@@ -35,6 +38,8 @@ def sample_requests(
tokenized_dataset = [] tokenized_dataset = []
for i in range(len(dataset)): for i in range(len(dataset)):
output_len = len(completion_token_ids[i]) output_len = len(completion_token_ids[i])
if fixed_output_len is not None:
output_len = fixed_output_len
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences. # Filter out too long sequences.
@@ -66,6 +71,7 @@ def run_vllm(
trust_remote_code: bool, trust_remote_code: bool,
dtype: str, dtype: str,
) -> float: ) -> float:
from vllm import LLM, SamplingParams
llm = LLM( llm = LLM(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@@ -94,7 +100,7 @@ def run_vllm(
) )
start = time.perf_counter() start = time.perf_counter()
# FIXME(woosuk): Do use internal method. # FIXME(woosuk): Do not use internal method.
llm._run_engine(use_tqdm=True) llm._run_engine(use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
return end - start return end - start
@@ -160,14 +166,37 @@ def run_hf(
return end - start return end - start
def run_mii(
requests: List[Tuple[str, int, int]],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import pipeline
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
start = time.perf_counter()
llm(prompts, max_new_tokens=output_len)
end = time.perf_counter()
return end - start
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
random.seed(args.seed) random.seed(args.seed)
# Sample the requests. # Sample the requests.
tokenizer = get_tokenizer(args.tokenizer, tokenizer = AutoTokenizer.from_pretrained(
trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer) if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer, elapsed_time = run_vllm(requests, args.model, args.tokenizer,
@@ -179,6 +208,9 @@ def main(args: argparse.Namespace):
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size, args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code) args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len total_num_tokens = sum(prompt_len + output_len
@@ -191,17 +223,26 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.") parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", parser.add_argument("--backend",
type=str, type=str,
choices=["vllm", "hf"], choices=["vllm", "hf", "mii"],
default="vllm") default="vllm")
parser.add_argument("--dataset", parser.add_argument("--dataset",
type=str, type=str,
required=True, default=None,
help="Path to the dataset.") help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m") parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
choices=['awq', None], choices=['awq', 'squeezellm', None],
default=None) default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", parser.add_argument("--n",
@@ -231,6 +272,13 @@ if __name__ == "__main__":
'for FP32 and FP16 models, and BF16 precision ' 'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.') 'for BF16 models.')
args = parser.parse_args() args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
if args.backend == "vllm": if args.backend == "vllm":
if args.hf_max_batch_size is not None: if args.hf_max_batch_size is not None:
@@ -240,7 +288,18 @@ if __name__ == "__main__":
raise ValueError("HF max batch size is required for HF backend.") raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None: if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.") raise ValueError("Quantization is only for vLLM backend.")
if args.tokenizer is None: elif args.backend == "mii":
args.tokenizer = args.model if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
main(args) main(args)
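
The new dataset-free path in this benchmark script builds identical synthetic requests from `--input-len` and `--output-len`. A minimal sketch of that step follows, assuming the same `(prompt, prompt_len, output_len)` tuple layout as in the diff; `synthesize_requests` and `total_tokens` are hypothetical names used only for illustration. Note that the prompt length is taken from the argument, not measured by tokenizing the synthetic string.

```python
from typing import List, Tuple

Request = Tuple[str, int, int]  # (prompt, prompt_len, output_len)


def synthesize_requests(input_len: int, output_len: int,
                        num_prompts: int) -> List[Request]:
    """Build identical synthetic requests when no dataset is given."""
    if input_len < 1 or output_len < 1 or num_prompts < 0:
        raise ValueError("lengths must be positive and num_prompts non-negative")
    # Same construction as the diff: repeat "hi" (input_len - 1) times.
    prompt = "hi" * (input_len - 1)
    return [(prompt, input_len, output_len) for _ in range(num_prompts)]


def total_tokens(requests: List[Request]) -> int:
    """Sum of prompt and output lengths, as used for the throughput figure."""
    return sum(prompt_len + output_len for _, prompt_len, output_len in requests)
```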

@@ -13,11 +13,11 @@ __device__ __forceinline__ T silu(const T& x) {
template<typename scalar_t> template<typename scalar_t>
__global__ void silu_and_mul_kernel( __global__ void silu_and_mul_kernel(
scalar_t* __restrict__ out, // [num_tokens, d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [num_tokens, 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
const int token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y; out[token_idx * d + idx] = silu(x) * y;
@@ -27,11 +27,11 @@ __global__ void silu_and_mul_kernel(
} // namespace vllm } // namespace vllm
void silu_and_mul( void silu_and_mul(
torch::Tensor& out, // [num_tokens, d] torch::Tensor& out, // [..., d]
torch::Tensor& input) // [num_tokens, 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.size(0); int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(1) / 2; int d = input.size(-1) / 2;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(d, 1024)); dim3 block(std::min(d, 1024));
@@ -52,11 +52,11 @@ namespace vllm {
// Element-wise activation kernel template. // Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel( __global__ void activation_kernel(
scalar_t* __restrict__ out, // [num_tokens, d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [num_tokens, d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
const int token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]); const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x); out[token_idx * d + idx] = ACT_FN(x);
} }
@@ -66,8 +66,8 @@ __global__ void activation_kernel(
// Launch element-wise activation kernel. // Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int num_tokens = input.size(0); \ int d = input.size(-1); \
int d = input.size(1); \ int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
@@ -100,15 +100,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
} // namespace vllm } // namespace vllm
void gelu_new( void gelu_new(
torch::Tensor& out, // [num_tokens, d] torch::Tensor& out, // [..., d]
torch::Tensor& input) // [num_tokens, d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
} }
void gelu_fast( void gelu_fast(
torch::Tensor& out, // [num_tokens, d] torch::Tensor& out, // [..., d]
torch::Tensor& input) // [num_tokens, d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
} }
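
The activation changes above replace the fixed `[num_tokens, d]` layout with `[..., d]`: every leading dimension is flattened into the token count, and only the last dimension is treated as the hidden size. A small Python sketch of the same arithmetic, with a hypothetical helper `launch_dims` that is not part of the diff:

```python
import math
from typing import Sequence, Tuple


def launch_dims(shape: Sequence[int], gated: bool) -> Tuple[int, int, int]:
    """Return (num_tokens, d, threads_per_block) for an input of `shape`.

    `gated=True` models silu_and_mul, whose last dimension holds 2 * d
    values; `gated=False` models the plain element-wise activations.
    """
    last = shape[-1]
    d = last // 2 if gated else last
    # numel / size(-1): all leading dimensions collapse into one token axis.
    num_tokens = math.prod(shape) // last
    # The kernels launch at most 1024 threads per block.
    return num_tokens, d, min(d, 1024)
```

With this rule, an input of shape `(4, 7, 256)` to the gated kernel is handled as 28 tokens with `d = 128`.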

@@ -175,7 +175,10 @@ __device__ void paged_attention_kernel(
 // dot product with the query.
 const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
 for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-const int physical_block_number = block_table[block_idx];
+// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+// because int32 can lead to overflow when this variable is multiplied by large numbers
+// (e.g., kv_block_stride).
+const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
 // Load a key to registers.
 // Each thread in a thread group has a different part of the key.
@@ -285,7 +288,10 @@ __device__ void paged_attention_kernel(
 scalar_t zero_value;
 zero(zero_value);
 for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-const int physical_block_number = block_table[block_idx];
+// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+// because int32 can lead to overflow when this variable is multiplied by large numbers
+// (e.g., kv_block_stride).
+const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
 const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
 const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
 L_vec logits_vec;
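
The NOTE in the hunks above says that multiplying an int32 block number by a large stride can overflow. The sketch below simulates that in Python, where integers do not overflow on their own. The block number and stride are illustrative values, not taken from the diff, and two's-complement wraparound is an assumption about typical hardware behavior (signed overflow is formally undefined in C++).

```python
INT32_MAX = 2**31 - 1


def wrap_int32(value: int) -> int:
    """Simulate two's-complement wraparound of a 32-bit signed integer."""
    value &= 0xFFFFFFFF
    return value - 2**32 if value > INT32_MAX else value


def key_offset(physical_block_number: int, kv_block_stride: int,
               use_int64: bool) -> int:
    """Offset of a block in the KV cache, with or without the int64 cast."""
    product = physical_block_number * kv_block_stride
    # Without the cast, the product is computed in 32 bits and may wrap.
    return product if use_int64 else wrap_int32(product)


# Illustrative numbers: stride = 40 heads * 128 head_size * 16 block_size.
stride = 40 * 128 * 16
print(key_offset(40000, stride, use_int64=True))   # exact 64-bit offset
print(key_offset(40000, stride, use_int64=False))  # wrapped, negative offset
```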

@@ -55,26 +55,26 @@ template<typename scalar_t>
__global__ void copy_blocks_kernel( __global__ void copy_blocks_kernel(
int64_t* key_cache_ptrs, int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs, int64_t* value_cache_ptrs,
const int* __restrict__ block_mapping, const int64_t* __restrict__ block_mapping,
const int numel_per_block) { const int numel_per_block) {
const int layer_idx = blockIdx.x; const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y; const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]); scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int src_block_number = block_mapping[2 * pair_idx]; int64_t src_block_number = block_mapping[2 * pair_idx];
int dst_block_number = block_mapping[2 * pair_idx + 1]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
const int src_block_offset = src_block_number * numel_per_block; const int64_t src_block_offset = src_block_number * numel_per_block;
const int dst_block_offset = dst_block_number * numel_per_block; const int64_t dst_block_offset = dst_block_number * numel_per_block;
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
int src_offset = src_block_offset + i; int64_t src_offset = src_block_offset + i;
int dst_offset = dst_block_offset + i; int64_t dst_offset = dst_block_offset + i;
key_cache[dst_offset] = key_cache[src_offset]; key_cache[dst_offset] = key_cache[src_offset];
} }
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
int src_offset = src_block_offset + i; int64_t src_offset = src_block_offset + i;
int dst_offset = dst_block_offset + i; int64_t dst_offset = dst_block_offset + i;
value_cache[dst_offset] = value_cache[src_offset]; value_cache[dst_offset] = value_cache[src_offset];
} }
} }
@@ -102,15 +102,15 @@ void copy_blocks(
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
} }
// Create block mapping array. // Create block mapping array.
std::vector<int> block_mapping_vec; std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) { for (const auto& pair : block_mapping) {
int src_block_number = pair.first; int64_t src_block_number = pair.first;
for (int dst_block_number : pair.second) { for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number); block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number); block_mapping_vec.push_back(dst_block_number);
} }
} }
int* block_mapping_array = block_mapping_vec.data(); int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2; int num_pairs = block_mapping_vec.size() / 2;
// Move the data structures to the GPU. // Move the data structures to the GPU.
@@ -120,7 +120,7 @@ void copy_blocks(
torch::Tensor value_cache_ptrs_tensor = torch::from_blob( torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor block_mapping_tensor = torch::from_blob( torch::Tensor block_mapping_tensor = torch::from_blob(
block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device); block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
// Launch the kernel. // Launch the kernel.
const int numel_per_block = key_caches[0][0].numel(); const int numel_per_block = key_caches[0][0].numel();
@@ -132,7 +132,7 @@ void copy_blocks(
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(), value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int>(), block_mapping_tensor.data_ptr<int64_t>(),
numel_per_block); numel_per_block);
})); }));
} }
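
On the host side, `copy_blocks` flattens the source-to-destinations mapping into a single array of `(src, dst)` pairs before moving it to the GPU, and the kernel reads entries `2 * pair_idx` and `2 * pair_idx + 1`. A minimal Python sketch of that layout; `flatten_block_mapping` and `block_offsets` are hypothetical helpers written for illustration:

```python
from typing import Dict, List, Tuple


def flatten_block_mapping(block_mapping: Dict[int, List[int]]) -> List[int]:
    """Flatten {src: [dst, ...]} into [src0, dst0, src1, dst1, ...]."""
    flat: List[int] = []
    for src_block_number, dst_block_numbers in block_mapping.items():
        for dst_block_number in dst_block_numbers:
            flat.extend((src_block_number, dst_block_number))
    return flat


def block_offsets(flat: List[int], pair_idx: int,
                  numel_per_block: int) -> Tuple[int, int]:
    """Element offsets of the source and destination blocks for one pair."""
    src_block_number = flat[2 * pair_idx]
    dst_block_number = flat[2 * pair_idx + 1]
    return src_block_number * numel_per_block, dst_block_number * numel_per_block
```

The number of pairs is `len(flat) // 2`, which corresponds to `num_pairs` in the diff.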
@@ -141,43 +141,48 @@ namespace vllm {
template<typename scalar_t> template<typename scalar_t>
__global__ void reshape_and_cache_kernel( __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int key_stride,
const int value_stride, const int value_stride,
const int num_heads, const int num_heads,
const int head_size, const int head_size,
const int block_size, const int block_size,
const int x) { const int x) {
const int token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
const int block_idx = slot_idx / block_size; if (slot_idx < 0) {
const int block_offset = slot_idx % block_size; // Padding token that should be ignored.
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size; const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) { for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int src_key_idx = token_idx * key_stride + i; const int64_t src_key_idx = token_idx * key_stride + i;
const int src_value_idx = token_idx * value_stride + i; const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size; const int head_idx = i / head_size;
const int head_offset = i % head_size; const int head_offset = i % head_size;
const int x_idx = head_offset / x; const int x_idx = head_offset / x;
const int x_offset = head_offset % x; const int x_offset = head_offset % x;
const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x + head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x + x_idx * block_size * x
+ block_offset * x + block_offset * x
+ x_offset; + x_offset;
const int tgt_value_idx = block_idx * num_heads * head_size * block_size const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size + head_idx * head_size * block_size
+ head_offset * block_size + head_offset * block_size
+ block_offset; + block_offset;
key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); key_cache[tgt_key_idx] = key[src_key_idx];
value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); value_cache[tgt_value_idx] = value[src_value_idx];
} }
} }
@@ -211,7 +216,7 @@ void reshape_and_cache(
value.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(), slot_mapping.data_ptr<int64_t>(),
key_stride, key_stride,
value_stride, value_stride,
num_heads, num_heads,
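
The index arithmetic in `reshape_and_cache_kernel` can be sketched in plain Python. This is an illustrative sketch, not the kernel itself: the helper name `key_cache_index` is hypothetical, and the example sizes are assumptions chosen to show that a flat cache index can exceed the 32-bit signed range, which is presumably why the diff widens these indices to `int64_t`.

```python
from typing import Optional

INT32_MAX = 2**31 - 1


def key_cache_index(slot_idx: int, head_idx: int, head_offset: int,
                    num_heads: int, head_size: int, block_size: int,
                    x: int) -> Optional[int]:
    """Flat index into a key cache laid out as
    [num_blocks, num_heads, head_size/x, block_size, x]."""
    if slot_idx < 0:
        # Padding token: the kernel returns early and writes nothing.
        return None
    block_idx, block_offset = divmod(slot_idx, block_size)
    x_idx, x_offset = divmod(head_offset, x)
    return (block_idx * num_heads * (head_size // x) * block_size * x
            + head_idx * (head_size // x) * block_size * x
            + x_idx * block_size * x
            + block_offset * x
            + x_offset)


# Assumed sizes for illustration: 32 heads, head_size 128, block_size 16, x = 8.
big = key_cache_index(slot_idx=40000 * 16, head_idx=0, head_offset=0,
                      num_heads=32, head_size=128, block_size=16, x=8)
print(big, big > INT32_MAX)
print(key_cache_index(-1, 0, 0, 32, 128, 16, 8))
```

With these assumed sizes, block 40000 already lands past what a 32-bit `int` can address, and a negative slot yields no write at all.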

View File

@@ -6,9 +6,19 @@ void rms_norm(
torch::Tensor& weight, torch::Tensor& weight,
float epsilon); float epsilon);
void fused_add_rms_norm(
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"rms_norm", "rms_norm",
&rms_norm, &rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor."); "Apply Root Mean Square (RMS) Normalization to the input tensor.");
m.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
} }

View File

@@ -9,8 +9,8 @@ namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template<typename scalar_t> template<typename scalar_t>
__global__ void rms_norm_kernel( __global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [num_tokens, hidden_size] scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [num_tokens, hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon,
const int num_tokens, const int num_tokens,
@@ -34,15 +34,45 @@ __global__ void rms_norm_kernel(
} }
} }
// TODO: Further optimize this kernel.
template<typename scalar_t>
__global__ void fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx];
x += (float) residual[blockIdx.x * hidden_size + idx];
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
}
}
} // namespace vllm } // namespace vllm
void rms_norm( void rms_norm(
torch::Tensor& out, // [num_tokens, hidden_size] torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [num_tokens, hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
float epsilon) { float epsilon) {
int num_tokens = input.size(0); int hidden_size = input.size(-1);
int hidden_size = input.size(1); int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
@@ -60,3 +90,28 @@ void rms_norm(
hidden_size); hidden_size);
}); });
} }
void fused_add_rms_norm(
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
}
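
What `fused_add_rms_norm_kernel` computes for a single token can be written as a small Python reference. This is a sketch under stated assumptions: it works on plain lists of floats, returns new lists instead of updating in place, and the function name `fused_add_rms_norm_ref` is hypothetical.

```python
import math
from typing import List, Tuple


def fused_add_rms_norm_ref(x: List[float], residual: List[float],
                           weight: List[float],
                           epsilon: float) -> Tuple[List[float], List[float]]:
    """Return (normalized output, updated residual) for one token."""
    # Step 1: add the residual; the kernel stores this sum back into `residual`.
    summed = [a + b for a, b in zip(x, residual)]
    # Step 2: mean of squares over the hidden dimension.
    variance = sum(v * v for v in summed) / len(summed)
    # Step 3: scale by the reciprocal root mean square, then by the weight;
    # the kernel stores this result back into `input`.
    scale = 1.0 / math.sqrt(variance + epsilon)
    normed = [v * scale * w for v, w in zip(summed, weight)]
    return normed, summed


normed, new_residual = fused_add_rms_norm_ref(
    [1.0, 2.0], [2.0, 2.0], [1.0, 1.0], epsilon=0.0)
print(new_residual, normed)
```

Fusing the two steps means the sum is read and written once, rather than once for the add and once more for the normalization.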

View File

@@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding(
template<typename scalar_t, bool IS_NEOX> template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel( __global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int rot_dim,
const int query_stride, const int query_stride,
@@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel(
} // namespace vllm } // namespace vllm
void rotary_embedding( void rotary_embedding(
torch::Tensor& positions, // [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int head_size, int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
int num_tokens = query.size(0); int64_t num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size; int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(1) / head_size; int num_kv_heads = key.size(-1) / head_size;
int query_stride = query.stride(0); int query_stride = query.stride(-2);
int key_stride = key.stride(0); int key_stride = key.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
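
Both `rms_norm` and `rotary_embedding` now derive the token count from the total element count divided by the last dimension, instead of reading `size(0)`. A minimal Python sketch of that rule follows; the helper name `token_count` is hypothetical and the shapes are example values.

```python
from math import prod
from typing import Sequence


def token_count(shape: Sequence[int]) -> int:
    """Number of rows when every leading dimension is flattened."""
    return prod(shape) // shape[-1]


# A flattened [num_tokens, hidden] input and a [batch, seq_len, hidden]
# input with the same number of tokens give the same launch grid size.
print(token_count((64, 4096)))
print(token_count((4, 16, 4096)))
```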

View File

@@ -7,9 +7,13 @@ torch::Tensor awq_gemm(
torch::Tensor _zeros, torch::Tensor _zeros,
int split_k_iters); int split_k_iters);
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def( m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
"awq_gemm", m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
&awq_gemm,
"Quantized GEMM for AWQ");
} }

View File

@@ -0,0 +1,148 @@
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
namespace vllm {
namespace squeezellm {
__device__ inline unsigned int as_unsigned(int i) {
return *reinterpret_cast<unsigned int*>(&i);
}
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
const half2* __restrict__ vec,
const int* __restrict__ mat,
half2* __restrict__ mul,
const __half* __restrict__ lookup_table,
int height,
int width,
int batch,
int vec_height
) {
const int blockwidth2 = BLOCKWIDTH / 2;
int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ half2 blockvec[blockwidth2];
__shared__ __half deq2[16][BLOCKWIDTH];
int off = threadIdx.x;
int column_offset = col * 16;
for (int val = 0; val < 16; val += 1) {
int lut_index = column_offset + val;
deq2[val][off] = lookup_table[lut_index];
}
__half res;
half2 res2;
half2 tmp2;
int i;
int k;
unsigned int tmp1;
unsigned int lut_index1, lut_index2;
for (int b = 0; b < batch; ++b){
i = width * row + col;
res = __int2half_rd(0);
k = 0;
__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
__syncthreads();
while (k < blockwidth2) {
tmp1 = as_unsigned(mat[i]);
res2 = {};
tmp2 = {};
lut_index1 = tmp1 & 0xF;
lut_index2 = (tmp1 >> 4) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
lut_index1 = (tmp1 >> 8) & 0xF;
lut_index2 = (tmp1 >> 12) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
lut_index1 = (tmp1 >> 16) & 0xF;
lut_index2 = (tmp1 >> 20) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
lut_index1 = (tmp1 >> 24) & 0xF;
lut_index2 = (tmp1 >> 28) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
res = __hadd(__hadd(res2.x, res2.y), res);
i += width;
k += 4;
}
// col%2 -> only set one of the two values
half2 res3 = {};
if (col % 2 == 0) {
res3.x = res;
} else {
res3.y = res;
}
atomicAdd(&mul[b * width / 2 + col / 2], res3);
}
}
} // namespace squeezellm
} // namespace vllm
// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
int height = mat.size(0);
int width = mat.size(1);
int batch = vec.size(0);
int vec_height = vec.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
(half2*) vec.data<at::Half>(),
mat.data_ptr<int>(),
(half2*) mul.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
height, width, batch, vec_height
);
}
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
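
The unpacking step in `NUQ4MatMulKernel` can be illustrated in Python. This is a minimal sketch, assuming a single matrix column with its own 16-entry lookup table; the function names `unpack_nibbles` and `lut_dot` are hypothetical, and the code uses Python floats rather than half precision.

```python
from typing import List


def unpack_nibbles(word: int) -> List[int]:
    """Split one 32-bit word into eight 4-bit lookup-table indices."""
    word &= 0xFFFFFFFF  # reinterpret as unsigned, like as_unsigned() in the kernel
    return [(word >> (4 * j)) & 0xF for j in range(8)]


def lut_dot(packed_column: List[int], lut: List[float],
            vec: List[float]) -> float:
    """Dot product of a 4-bit packed column with an input vector.

    Each packed word covers eight consecutive vector elements; every
    4-bit index is dequantized through the column's lookup table.
    """
    acc = 0.0
    for row, word in enumerate(packed_column):
        for j, idx in enumerate(unpack_nibbles(word)):
            acc += lut[idx] * vec[8 * row + j]
    return acc


lut = [float(i) for i in range(16)]
print(unpack_nibbles(0x76543210))
print(lut_dot([0x76543210], lut, [1.0] * 8))
```

The kernel performs the same lookups two at a time with `__hfma2` and accumulates the per-column result into the output with `atomicAdd`.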

View File

@@ -40,6 +40,16 @@ Initialize vLLM's engine for offline inference with the ``LLM`` class and the `O
llm = LLM(model="facebook/opt-125m") llm = LLM(model="facebook/opt-125m")
To use a model from www.modelscope.cn:
.. code-block:: shell
export VLLM_USE_MODELSCOPE=True
.. code-block:: python
llm = LLM(model="qwen/Qwen-7B-Chat", revision="v1.1.8", trust_remote_code=True)
Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens. Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens.
.. code-block:: python .. code-block:: python
@@ -67,6 +77,16 @@ Start the server:
$ python -m vllm.entrypoints.api_server $ python -m vllm.entrypoints.api_server
To use a model from www.modelscope.cn:
.. code-block:: console
$ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.api_server \
$ --model="qwen/Qwen-7B-Chat" \
$ --revision="v1.1.8" \
$ --trust-remote-code
By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model. By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
Query the model in shell: Query the model in shell:
@@ -95,6 +115,13 @@ Start the server:
$ python -m vllm.entrypoints.openai.api_server \ $ python -m vllm.entrypoints.openai.api_server \
$ --model facebook/opt-125m $ --model facebook/opt-125m
To use a model from www.modelscope.cn:
.. code-block:: console
$ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
$ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints. By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
This server can be queried in the same format as OpenAI API. For example, list the models: This server can be queried in the same format as OpenAI API. For example, list the models:

View File

@@ -65,6 +65,7 @@ Documentation
serving/distributed_serving serving/distributed_serving
serving/run_on_sky serving/run_on_sky
serving/deploying_with_triton serving/deploying_with_triton
serving/deploying_with_docker
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
@@ -72,3 +73,9 @@ Documentation
models/supported_models models/supported_models
models/adding_model models/adding_model
.. toctree::
:maxdepth: 1
:caption: Quantization
quantization/auto_awq

View File

@@ -62,31 +62,34 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+) -> SamplerOutput: +) -> SamplerOutput:
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture. 4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
.. note:: .. note::
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM. If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
3. (Optional) Implement tensor parallelism support 3. (Optional) Implement tensor parallelism and quantization support
-------------------------------------------------- -------------------------------------------------------------------
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it. If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions. To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
When it comes to the linear layers, you should use either :code:`RowParallelLinear` or :code:`ColumnParallelLinear`. When it comes to the linear layers, we provide the following options to parallelize them:
Typically, :code:`ColumnParallelLinear` is used for QKV linear layers and the first linear layers of the MLP blocks.
For the remaining linear layers, :code:`RowParallelLinear` is used.
* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
* :code:`RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
* :code:`ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple :code:`ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
* :code:`QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When the number of key/value heads is less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.
Note that all the linear layers above take :code:`linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
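
The difference between the column-parallel and row-parallel partitioning schemes described above can be sketched in plain Python. This is an illustration under stated assumptions, not vLLM's implementation: the weight is stored as ``[input_dim][output_dim]``, the "GPUs" are simulated by a loop, and the function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: List[float], w: Matrix) -> List[float]:
    """y = x @ w, with w stored as [input_dim][output_dim]."""
    return [sum(x[i] * w[i][j] for i in range(len(x)))
            for j in range(len(w[0]))]


def column_parallel(x: List[float], w: Matrix, world_size: int) -> List[float]:
    """Each rank holds a slice of the output columns; the input is replicated."""
    shard = len(w[0]) // world_size
    y: List[float] = []
    for rank in range(world_size):
        w_shard = [row[rank * shard:(rank + 1) * shard] for row in w]
        y.extend(matmul(x, w_shard))  # concatenating shards plays the role of a gather
    return y


def row_parallel(x: List[float], w: Matrix, world_size: int) -> List[float]:
    """Each rank holds a slice of the input rows and the matching input slice."""
    shard = len(w) // world_size
    partials = [
        matmul(x[r * shard:(r + 1) * shard], w[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    # Summing the partial results plays the role of the all-reduce.
    return [sum(vals) for vals in zip(*partials)]


x = [1.0, 2.0, 3.0, 4.0]
w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(matmul(x, w), column_parallel(x, w, 2), row_parallel(x, w, 2))
```

Both schemes reproduce the unpartitioned result; they differ in whether each rank ends up with a slice of the output (column-parallel) or a partial sum that must be reduced (row-parallel).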
4. Implement the weight loading logic 4. Implement the weight loading logic
------------------------------------- -------------------------------------
You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class. You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.
While the process is straightforward for most layers, the tensor-parallel layers necessitate some additional care as their weights should be partitioned to multiple GPUs.
5. Register your model 5. Register your model
---------------------- ----------------------

View File

@@ -20,6 +20,9 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`BaiChuanForCausalLM` * - :code:`BaiChuanForCausalLM`
- Baichuan - Baichuan
- :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc. - :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
* - :code:`ChatGLMModel`
- ChatGLM
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
* - :code:`BloomForCausalLM` * - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat - BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
@@ -53,9 +56,15 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`OPTForCausalLM` * - :code:`OPTForCausalLM`
- OPT, OPT-IML - OPT, OPT-IML
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
* - :code:`PhiForCausalLM`
- Phi-1.5
- :code:`microsoft/phi-1_5`, etc.
* - :code:`QWenLMHeadModel` * - :code:`QWenLMHeadModel`
- Qwen - Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
* - :code:`YiForCausalLM`
- Yi
- :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model. Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
@@ -72,4 +81,18 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
output = llm.generate("Hello, my name is") output = llm.generate("Hello, my name is")
print(output) print(output)
To use a model from www.modelscope.cn:
.. code-block:: shell
$ export VLLM_USE_MODELSCOPE=True
.. code-block:: python
from vllm import LLM
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
output = llm.generate("Hello, my name is")
print(output)
If vLLM successfully generates text, it indicates that your model is supported. If vLLM successfully generates text, it indicates that your model is supported.

View File

@@ -0,0 +1,69 @@
.. _auto_awq:
AutoAWQ
==================
To create a new 4-bit quantized model, you can leverage `AutoAWQ <https://github.com/casper-hansen/AutoAWQ>`_.
Quantizing reduces the model's precision from FP16 to INT4, which reduces the file size by roughly 70%.
The main benefits are lower latency and memory usage.
You can quantize your own models by installing AutoAWQ, or pick one of the `400+ models on Hugging Face <https://huggingface.co/models?sort=trending&search=awq>`_ that are already quantized.
.. code-block:: console
$ pip install autoawq
After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize Vicuna 7B v1.5:
.. code-block:: python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
To run an AWQ model with vLLM, you can use `TheBloke/Llama-2-7b-Chat-AWQ <https://huggingface.co/TheBloke/Llama-2-7b-Chat-AWQ>`_ with the following command:
.. code-block:: console
$ python examples/llm_engine_example.py --model TheBloke/Llama-2-7b-Chat-AWQ --quantization awq
AWQ models are also supported directly through the LLM entrypoint:
.. code-block:: python
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="TheBloke/Llama-2-7b-Chat-AWQ", quantization="AWQ")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
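
The size reduction quoted at the top of this page can be sanity-checked with a back-of-the-envelope calculation. This sketch assumes that each group of ``q_group_size`` weights stores one FP16 scale and one 4-bit zero point, and it ignores any layers left unquantized; both are assumptions for illustration, not a description of the exact file layout.

```python
def awq_bits_per_weight(w_bit: int = 4, q_group_size: int = 128,
                        scale_bits: int = 16, zero_bits: int = 4) -> float:
    """Average storage per weight, including per-group metadata."""
    return w_bit + (scale_bits + zero_bits) / q_group_size


def size_reduction(bits_per_weight: float, baseline_bits: int = 16) -> float:
    """Fraction of storage saved relative to an FP16 baseline."""
    return 1.0 - bits_per_weight / baseline_bits


bits = awq_bits_per_weight()
print(f"{bits:.3f} bits per weight, {size_reduction(bits):.1%} smaller than FP16")
```

Under these assumptions the quantized linear weights shrink by about 74%; parts of the model that stay in FP16 would pull the overall figure down toward the quoted value.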

View File

@@ -0,0 +1,21 @@
.. _deploying_with_docker:
Deploying with Docker
============================
You can build and run vLLM from source using the provided Dockerfile. To build vLLM:
.. code-block:: console
$ DOCKER_BUILDKIT=1 docker build . --target vllm --tag vllm --build-arg max_jobs=8
To run vLLM:
.. code-block:: console
$ docker run --runtime nvidia --gpus all \
-v ~/.cache/huggingface:/root/.cache/huggingface \
-p 8000:8000 \
--env "HUGGING_FACE_HUB_TOKEN=<secret>" \
vllm <args...>

View File

@@ -1,15 +1,12 @@
import argparse import argparse
from typing import List, Tuple
from vllm import EngineArgs, LLMEngine, SamplingParams from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
def main(args: argparse.Namespace): def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
# Parse the CLI argument and initialize the engine. """Create a list of test prompts with their sampling parameters."""
engine_args = EngineArgs.from_cli_args(args) return [
engine = LLMEngine.from_engine_args(engine_args)
# Test the following prompts.
test_prompts = [
("A robot may not injure a human being", ("A robot may not injure a human being",
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)), SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
("To be or not to be,", ("To be or not to be,",
@@ -25,22 +22,36 @@ def main(args: argparse.Namespace):
temperature=0.0)), temperature=0.0)),
] ]
# Run the engine by calling `engine.step()` manually.
def process_requests(engine: LLMEngine,
test_prompts: List[Tuple[str, SamplingParams]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0 request_id = 0
while True:
# To test continuous batching, we add one request at each step. while test_prompts or engine.has_unfinished_requests():
if test_prompts: if test_prompts:
prompt, sampling_params = test_prompts.pop(0) prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params) engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1 request_id += 1
request_outputs = engine.step() request_outputs: List[RequestOutput] = engine.step()
for request_output in request_outputs: for request_output in request_outputs:
if request_output.finished: if request_output.finished:
print(request_output) print(request_output)
if not (engine.has_unfinished_requests() or test_prompts):
break def initialize_engine(args: argparse.Namespace) -> LLMEngine:
"""Initialize the LLMEngine from the command line arguments."""
engine_args = EngineArgs.from_cli_args(args)
return LLMEngine.from_engine_args(engine_args)
def main(args: argparse.Namespace):
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine(args)
test_prompts = create_test_prompts()
process_requests(engine, test_prompts)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -93,9 +93,43 @@ echo 'vLLM yapf: Done'
# echo 'vLLM mypy:' # echo 'vLLM mypy:'
# mypy # mypy
# Lint specified files
lint() {
pylint "$@"
}
# Lint files that differ from main branch. Ignores dirs that are not slated
# for autolint yet.
lint_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause pylint to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE ensure that we only lint files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
pylint
fi
}
# Run Pylint # Run Pylint
echo 'vLLM Pylint:' echo 'vLLM Pylint:'
pylint vllm tests ## This flag lints individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
lint "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is linted.
elif [[ "$1" == '--all' ]]; then
lint vllm tests
else
# Lint only the files that differ from the main branch.
lint_changed
fi
if ! git diff --quiet &>/dev/null; then if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.' echo 'Reformatted files. Please review and stage the changes.'

View File

@@ -3,7 +3,7 @@ requires = [
"ninja", "ninja",
"packaging", "packaging",
"setuptools", "setuptools",
"torch == 2.0.1", "torch >= 2.1.0",
"wheel", "wheel",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -12,3 +12,4 @@ types-setuptools
pytest pytest
pytest-forked pytest-forked
pytest-asyncio pytest-asyncio

View File

@@ -5,9 +5,10 @@ pandas # Required for Ray data.
pyarrow # Required for Ray data. pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy numpy
torch == 2.0.1 einops # Required for phi-1_5
torch >= 2.1.0
transformers >= 4.34.0 # Required for Mistral. transformers >= 4.34.0 # Required for Mistral.
xformers == 0.0.22 # Required for Mistral. xformers >= 0.0.22.post7 # Required for CUDA 12.1.
fastapi fastapi
uvicorn[standard] uvicorn[standard]
pydantic < 2 # Required for OpenAI server. pydantic == 1.10.13 # Required for OpenAI server.

View File

@@ -12,6 +12,8 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
ROOT_DIR = os.path.dirname(__file__) ROOT_DIR = os.path.dirname(__file__)
MAIN_CUDA_VERSION = "12.1"
# Supported NVIDIA GPU architectures. # Supported NVIDIA GPU architectures.
SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
@@ -200,6 +202,7 @@ quantization_extension = CUDAExtension(
sources=[ sources=[
"csrc/quantization.cpp", "csrc/quantization.cpp",
"csrc/quantization/awq/gemm_kernels.cu", "csrc/quantization/awq/gemm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
], ],
extra_compile_args={ extra_compile_args={
"cxx": CXX_FLAGS, "cxx": CXX_FLAGS,
@@ -224,7 +227,7 @@ def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath) return os.path.join(ROOT_DIR, *filepath)
def find_version(filepath: str): def find_version(filepath: str) -> str:
"""Extract version information from the given filepath. """Extract version information from the given filepath.
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
@@ -237,9 +240,22 @@ def find_version(filepath: str):
raise RuntimeError("Unable to find version string.") raise RuntimeError("Unable to find version string.")
def get_vllm_version() -> str:
version = find_version(get_path("vllm", "__init__.py"))
cuda_version = str(nvcc_cuda_version)
if cuda_version != MAIN_CUDA_VERSION:
cuda_version_str = cuda_version.replace(".", "")[:3]
version += f"+cu{cuda_version_str}"
return version
def read_readme() -> str: def read_readme() -> str:
"""Read the README file.""" """Read the README file if present."""
return io.open(get_path("README.md"), "r", encoding="utf-8").read() p = get_path("README.md")
if os.path.isfile(p):
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
else:
return ""
def get_requirements() -> List[str]: def get_requirements() -> List[str]:
@@ -251,7 +267,7 @@ def get_requirements() -> List[str]:
setuptools.setup( setuptools.setup(
name="vllm", name="vllm",
version=find_version(get_path("vllm", "__init__.py")), version=get_vllm_version(),
author="vLLM Team", author="vLLM Team",
license="Apache 2.0", license="Apache 2.0",
description=("A high-throughput and memory-efficient inference and " description=("A high-throughput and memory-efficient inference and "
@@ -277,4 +293,5 @@ setuptools.setup(
install_requires=get_requirements(), install_requires=get_requirements(),
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension}, cmdclass={"build_ext": BuildExtension},
package_data={"vllm": ["py.typed"]},
) )
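
Reviewer note: the `get_vllm_version` hunk above appends a local version label when the build's CUDA version differs from `MAIN_CUDA_VERSION`. A minimal standalone sketch of that suffix logic follows. The helper name `with_cuda_suffix` and its parameters are hypothetical; in `setup.py` the CUDA version is taken from `nvcc_cuda_version` rather than passed in.

```python
MAIN_CUDA_VERSION = "12.1"


def with_cuda_suffix(base_version: str,
                     cuda_version: str,
                     main_cuda_version: str = MAIN_CUDA_VERSION) -> str:
    """Append a "+cuXYZ" local version label for non-default CUDA builds.

    Hypothetical helper mirroring the logic in the diff, not the actual
    setup.py function.
    """
    if cuda_version != main_cuda_version:
        # "11.8" -> "118"; keep at most three characters, as in the diff.
        cuda_version_str = cuda_version.replace(".", "")[:3]
        return f"{base_version}+cu{cuda_version_str}"
    return base_version


print(with_cuda_suffix("0.2.2", "11.8"))  # 0.2.2+cu118
print(with_cuda_suffix("0.2.2", "12.1"))  # 0.2.2
```

Under this scheme, wheels built for the main CUDA version keep the plain version string and only the other builds carry a suffix.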

0
tests/__init__.py Normal file
View File

View File

@@ -2,7 +2,7 @@
Run `pytest tests/distributed/test_comm_ops.py --forked`. Run `pytest tests/distributed/test_comm_ops.py --forked`.
""" """
from multiprocessing import Process from multiprocessing import Process, set_start_method
import pytest import pytest
import torch import torch
@@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@pytest.mark.parametrize("test_target", @pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker]) [all_reduce_test_worker, all_gather_test_worker])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
set_start_method("spawn", force=True)
distributed_init_port = get_open_port() distributed_init_port = get_open_port()
processes = [] processes = []
for rank in range(tensor_parallel_size): for rank in range(tensor_parallel_size):

View File

@@ -13,7 +13,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability. # This will change depending on the compute capability.
# - 512 as a buffer # - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 128 # Arbitrary values for testing NUM_BLOCKS = 40000 # Arbitrary values for testing
PARTITION_SIZE = 512 PARTITION_SIZE = 512
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]

View File

@@ -6,13 +6,13 @@ import torch
from vllm import cache_ops from vllm import cache_ops
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing NUM_TOKENS = [83] # Arbitrary values for testing
NUM_LAYERS = [5] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256] HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32] BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024] # Arbitrary values for testing NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing
NUM_MAPPINGS = [32, 256] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0] SEEDS = [0]
@@ -69,9 +69,9 @@ def test_copy_blocks(
for src, dsts in block_mapping.items(): for src, dsts in block_mapping.items():
for dst in dsts: for dst in dsts:
for cloned_key_cache in cloned_key_caches: for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst] = cloned_key_cache[src] cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches: for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst] = cloned_value_cache[src] cloned_value_cache[dst].copy_(cloned_value_cache[src])
# Compare the results. # Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
@@ -106,7 +106,7 @@ def test_reshape_and_cache(
# Create a random slot mapping. # Create a random slot mapping.
num_slots = block_size * num_blocks num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens) slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda") slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
qkv = torch.randn(num_tokens, qkv = torch.randn(num_tokens,
3, 3,

View File

@@ -6,14 +6,16 @@ import pytest
MODELS = [ MODELS = [
"facebook/opt-125m", "facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1",
"tiiuae/falcon-7b",
"gpt2", "gpt2",
"bigcode/tiny_starcoder_py", "bigcode/tiny_starcoder_py",
"EleutherAI/gpt-j-6b", "EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m", "EleutherAI/pythia-70m",
"bigscience/bloom-560m", "bigscience/bloom-560m",
"mosaicml/mpt-7b", "mosaicml/mpt-7b",
"tiiuae/falcon-7b", "microsoft/phi-1_5",
"meta-llama/Llama-2-7b-hf",
] ]

View File

@@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int):
continue continue
for nth_output in sequence_output.samples: for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens assert nth_output.output_token in expected_tokens
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, worker = _prepare_test(batch_size)
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
for idx, nth_output in enumerate(sequence_output.samples):
assert nth_output.output_token == idx
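
Reviewer note: the new test relies on a logits processor being a callable that takes a token list and a logits vector and returns the modified logits. The sketch below shows that contract with plain Python lists and a greedy pick. `greedy_next_token` and `generate` are hypothetical helpers, and the sketch assumes the processor receives the tokens generated so far; the real sampler operates on tensors.

```python
from typing import Callable, List

LogitsProcessor = Callable[[List[int], List[float]], List[float]]


def pick_ith(token_ids: List[int], logits: List[float]) -> List[float]:
    # Give an infinite score to the token whose id equals len(token_ids).
    logits[len(token_ids)] = float("inf")
    return logits


def greedy_next_token(token_ids: List[int], logits: List[float],
                      processors: List[LogitsProcessor]) -> int:
    # Apply each processor in order, then take the argmax (temperature=0).
    for processor in processors:
        logits = processor(token_ids, logits)
    return max(range(len(logits)), key=logits.__getitem__)


def generate(num_steps: int, vocab_size: int) -> List[int]:
    # Hypothetical driver: every step starts from all-zero logits.
    generated: List[int] = []
    for _ in range(num_steps):
        logits = [0.0] * vocab_size
        generated.append(greedy_next_token(generated, logits, [pick_ith]))
    return generated


print(generate(num_steps=4, vocab_size=8))  # [0, 1, 2, 3]
```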

27
tests/test_regression.py Normal file
View File

@@ -0,0 +1,27 @@
"""Containing tests that check for regressions in vLLM's behavior.
It should include tests that are reported by users and making sure they
will never happen again.
"""
from vllm import LLM, SamplingParams
def test_duplicated_ignored_sequence_group():
"""https://github.com/vllm-project/vllm/issues/1655"""
sampling_params = SamplingParams(temperature=0.01,
top_p=0.1,
max_tokens=256)
llm = LLM(model="facebook/opt-125m",
max_num_batched_tokens=4096,
tensor_parallel_size=1)
prompts = ["This is a short prompt", "This is a very long prompt " * 1000]
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(prompts) == len(outputs)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@@ -0,0 +1,44 @@
# pylint: disable=protected-access
import random
import torch
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker
def test_worker_prepare_inputs_for_prompt():
worker = Worker(None, None, None)
worker.block_size = 16
batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (worker.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
))
expected_selected_token_indices = []
selected_token_start_idx = 0
max_seq_len = max(prompt_lens)
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
input_tokens, input_positions, input_metadata = worker._prepare_inputs(
seq_group_metadata_list)
assert input_tokens.shape == input_positions.shape == (batch_size,
max_seq_len)
torch.testing.assert_close(input_tokens, input_positions)
actual = input_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)

View File

@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
__version__ = "0.2.1" __version__ = "0.2.2"
__all__ = [ __all__ = [
"LLM", "LLM",

View File

@@ -1,4 +1,5 @@
from typing import Optional from typing import Optional, Union
import os
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
@@ -58,7 +59,7 @@ class ModelConfig:
trust_remote_code: bool, trust_remote_code: bool,
download_dir: Optional[str], download_dir: Optional[str],
load_format: str, load_format: str,
dtype: str, dtype: Union[str, torch.dtype],
seed: int, seed: int,
revision: Optional[str] = None, revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
@@ -76,7 +77,18 @@ class ModelConfig:
self.tokenizer_revision = tokenizer_revision self.tokenizer_revision = tokenizer_revision
self.quantization = quantization self.quantization = quantization
self.hf_config = get_config(model, trust_remote_code, revision) if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C
model_path = snapshot_download(model_id=model,
cache_dir=download_dir,
revision=revision)
self.model = model_path
self.download_dir = model_path
self.tokenizer = model_path
self.hf_config = get_config(self.model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config, self.max_model_len = _get_and_verify_max_len(self.hf_config,
max_model_len) max_model_len)
@@ -103,15 +115,31 @@ class ModelConfig:
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = ["awq"] supported_quantization = ["awq", "squeezellm"]
if self.quantization is None: if self.quantization is not None:
return self.quantization = self.quantization.lower()
quantization = self.quantization.lower()
if quantization not in supported_quantization: # Parse quantization method from the HF model config, if available.
raise ValueError( hf_quant_config = getattr(self.hf_config, "quantization_config", None)
f"Unknown quantization: {self.quantization}. Must be one of " if hf_quant_config is not None:
f"{supported_quantization}.") hf_quant_method = str(hf_quant_config["quant_method"]).lower()
self.quantization = quantization if self.quantization is None:
self.quantization = hf_quant_method
elif self.quantization != hf_quant_method:
raise ValueError(
"Quantization method specified in the model config "
f"({hf_quant_method}) does not match the quantization "
f"method specified in the `quantization` argument "
f"({self.quantization}).")
if self.quantization is not None:
if self.quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
logger.warning(f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
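
Reviewer note: the rewritten `_verify_quantization` reconciles the `quantization` argument with the `quantization_config` found in the HF config. A standalone sketch of that decision logic is below. `resolve_quantization` is a hypothetical free function (the diff implements this as a method on `ModelConfig`), and the warning log is omitted.

```python
from typing import Optional

SUPPORTED_QUANTIZATION = ["awq", "squeezellm"]


def resolve_quantization(arg: Optional[str],
                         hf_quant_config: Optional[dict]) -> Optional[str]:
    """Return the effective quantization method, or None if unquantized."""
    method = arg.lower() if arg is not None else None

    # Prefer the method declared in the model config; reject a mismatch.
    if hf_quant_config is not None:
        hf_method = str(hf_quant_config["quant_method"]).lower()
        if method is None:
            method = hf_method
        elif method != hf_method:
            raise ValueError(
                f"Model config says {hf_method}, argument says {method}.")

    if method is not None and method not in SUPPORTED_QUANTIZATION:
        raise ValueError(f"Unknown quantization method: {method}. "
                         f"Must be one of {SUPPORTED_QUANTIZATION}.")
    return method


print(resolve_quantization(None, {"quant_method": "AWQ"}))  # awq
print(resolve_quantization("squeezellm", None))             # squeezellm
print(resolve_quantization(None, None))                     # None
```

The practical effect is that a checkpoint that declares its own quantization method no longer needs `--quantization` on the command line, and a contradictory flag fails early.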
@@ -140,8 +168,8 @@ class ModelConfig:
# FIXME(woosuk): This may not be true for all models. # FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: def get_total_num_kv_heads(self) -> int:
"""Returns the number of KV heads per GPU worker.""" """Returns the total number of KV heads."""
# For GPTBigCode & Falcon: # For GPTBigCode & Falcon:
# NOTE: for falcon, when new_decoder_architecture is True, the # NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of # multi_query flag is ignored and we use n_head_kv for the number of
@@ -155,19 +183,34 @@ class ModelConfig:
# Multi-query attention, only one KV head. # Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case. # Currently, tensor parallelism is not supported in this case.
return 1 return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None: attributes = [
return (self.hf_config.n_head_kv // # For Falcon:
parallel_config.tensor_parallel_size) "n_head_kv",
if getattr(self.hf_config, "num_kv_heads", None) is not None: "num_kv_heads",
return (self.hf_config.num_kv_heads // # For LLaMA-2:
parallel_config.tensor_parallel_size) "num_key_value_heads",
# For LLaMA-2: # For ChatGLM:
if getattr(self.hf_config, "num_key_value_heads", None) is not None: "multi_query_group_num",
return (self.hf_config.num_key_value_heads // ]
parallel_config.tensor_parallel_size) for attr in attributes:
total_num_attention_heads = self.hf_config.num_attention_heads num_kv_heads = getattr(self.hf_config, attr, None)
return total_num_attention_heads // parallel_config.tensor_parallel_size if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
return self.hf_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1,
total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_layers(self, parallel_config: "ParallelConfig") -> int: def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers total_num_hidden_layers = self.hf_config.num_hidden_layers
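
Reviewer note: the split into `get_total_num_kv_heads` and `get_num_kv_heads` changes the per-GPU count from a plain integer division to one clamped at 1, so KV heads are replicated when there are fewer of them than tensor-parallel ranks. A small sketch of the arithmetic, with `num_kv_heads_per_gpu` as a hypothetical helper name:

```python
def num_kv_heads_per_gpu(total_num_kv_heads: int,
                         tensor_parallel_size: int) -> int:
    # Integer division, but never fewer than one KV head per GPU.
    return max(1, total_num_kv_heads // tensor_parallel_size)


print(num_kv_heads_per_gpu(32, 4))  # 8
print(num_kv_heads_per_gpu(8, 8))   # 1
print(num_kv_heads_per_gpu(8, 16))  # 1
```

In the last case the old code would have returned `8 // 16 == 0`; with the clamp, each GPU holds one (replicated) KV head.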
@@ -268,6 +311,7 @@ class SchedulerConfig:
iteration. iteration.
max_model_len: Maximum length of a sequence (including prompt max_model_len: Maximum length of a sequence (including prompt
and generated text). and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
""" """
def __init__( def __init__(
@@ -275,6 +319,7 @@ class SchedulerConfig:
max_num_batched_tokens: Optional[int], max_num_batched_tokens: Optional[int],
max_num_seqs: int, max_num_seqs: int,
max_model_len: int, max_model_len: int,
max_paddings: int,
) -> None: ) -> None:
if max_num_batched_tokens is not None: if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
@@ -284,6 +329,7 @@ class SchedulerConfig:
self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.max_paddings = max_paddings
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
@@ -313,7 +359,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
def _get_and_verify_dtype( def _get_and_verify_dtype(
config: PretrainedConfig, config: PretrainedConfig,
dtype: str, dtype: Union[str, torch.dtype],
) -> torch.dtype: ) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None. # because config.torch_dtype can be None.
@@ -321,17 +367,23 @@ def _get_and_verify_dtype(
if config_dtype is None: if config_dtype is None:
config_dtype = torch.float32 config_dtype = torch.float32
dtype = dtype.lower() if isinstance(dtype, str):
if dtype == "auto": dtype = dtype.lower()
if config_dtype == torch.float32: if dtype == "auto":
# Following the common practice, we use float16 for float32 models. if config_dtype == torch.float32:
torch_dtype = torch.float16 # Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else: else:
torch_dtype = config_dtype if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else: else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype}")
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
# Verify the dtype. # Verify the dtype.
if torch_dtype != config_dtype: if torch_dtype != config_dtype:
@@ -361,6 +413,8 @@ def _get_and_verify_max_len(
"n_positions", "n_positions",
# MPT # MPT
"max_seq_len", "max_seq_len",
# ChatGLM2
"seq_length",
# Others # Others
"max_sequence_length", "max_sequence_length",
"max_seq_length", "max_seq_length",
@@ -387,6 +441,9 @@ def _get_and_verify_max_len(
if rope_scaling is not None: if rope_scaling is not None:
assert "factor" in rope_scaling assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor derived_max_model_len *= scaling_factor
if max_model_len is None: if max_model_len is None:
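
Reviewer note: the YaRN branch added to `_get_and_verify_max_len` resets the base length to `original_max_position_embeddings` before the scaling factor is applied. The sketch below isolates that step. `apply_rope_scaling` is a hypothetical helper and the numbers are illustrative, not taken from any particular model config.

```python
from typing import Optional


def apply_rope_scaling(derived_max_model_len: float,
                       rope_scaling: Optional[dict]) -> float:
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            # For YaRN, scale from the original (pre-extension) length.
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor
    return derived_max_model_len


print(apply_rope_scaling(4096, {"type": "linear", "factor": 2.0}))  # 8192.0
print(apply_rope_scaling(
    131072,
    {"type": "yarn", "factor": 32.0,
     "original_max_position_embeddings": 4096},
))  # 131072.0
```

Without the reset, the second call would multiply an already-extended length by the factor again.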

View File

@@ -131,7 +131,8 @@ class Scheduler:
# requests in the generation phase. # requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs() num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running) for seq_group in self.running)
num_batched_tokens = 0 seq_lens: List[int] = []
# Optimization: We do not sort the waiting queue since the preempted # Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups # sequence groups are added to the front and the new sequence groups
# are added to the back. # are added to the back.
@@ -157,7 +158,9 @@ class Scheduler:
break break
# If the number of batched tokens exceeds the limit, stop. # If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens > new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens): self.scheduler_config.max_num_batched_tokens):
break break
@@ -168,10 +171,14 @@ class Scheduler:
self.scheduler_config.max_num_seqs): self.scheduler_config.max_num_seqs):
break break
num_paddings = num_batched_tokens - sum(new_seq_lens)
if num_paddings > self.scheduler_config.max_paddings:
break
seq_lens = new_seq_lens
seq_group = self.waiting.pop(0) seq_group = self.waiting.pop(0)
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group) self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
scheduled.append(seq_group) scheduled.append(seq_group)
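
Reviewer note: with this hunk the prompt-phase token budget is computed on the padded batch (`number of sequences * longest prompt`), and a new `max_paddings` limit caps the wasted slots. The sketch below replays just those two checks over a list of prompt lengths. `admit_prompts` is a hypothetical helper; the real loop also checks `max_num_seqs`, block availability, and the per-prompt length limit.

```python
from typing import List


def admit_prompts(prompt_lens: List[int], max_num_batched_tokens: int,
                  max_paddings: int) -> List[int]:
    """Return the lengths of the prompts admitted into one padded batch."""
    seq_lens: List[int] = []
    for num_prompt_tokens in prompt_lens:
        new_seq_lens = seq_lens + [num_prompt_tokens]
        # Every sequence is padded to the longest prompt in the batch.
        num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
        if num_batched_tokens > max_num_batched_tokens:
            break
        num_paddings = num_batched_tokens - sum(new_seq_lens)
        if num_paddings > max_paddings:
            break
        seq_lens = new_seq_lens
    return seq_lens


# The 500-token prompt would add 1500 - 522 = 978 paddings, so it waits.
print(admit_prompts([10, 12, 500], 4096, 256))   # [10, 12]
# The third prompt would bring the padded batch to 300 > 250 tokens.
print(admit_prompts([100, 100, 100], 250, 256))  # [100, 100]
```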
@@ -179,7 +186,7 @@ class Scheduler:
scheduler_outputs = SchedulerOutputs( scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled, scheduled_seq_groups=scheduled,
prompt_run=True, prompt_run=True,
num_batched_tokens=num_batched_tokens, num_batched_tokens=len(seq_lens) * max(seq_lens),
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
@@ -268,7 +275,7 @@ class Scheduler:
# Create input data structures. # Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
for seq_group in scheduler_outputs.scheduled_seq_groups: for seq_group in scheduler_outputs.scheduled_seq_groups:
seq_data: Dict[int, List[SequenceData]] = {} seq_data: Dict[int, SequenceData] = {}
block_tables: Dict[int, List[int]] = {} block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id seq_id = seq.seq_id

View File

@@ -27,6 +27,7 @@ class EngineArgs:
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256 max_num_seqs: int = 256
max_paddings: int = 256
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
tokenizer_revision: Optional[str] = None tokenizer_revision: Optional[str] = None
@@ -156,6 +157,10 @@ class EngineArgs:
type=int, type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration') help='maximum number of sequences per iteration')
parser.add_argument('--max-paddings',
type=int,
default=EngineArgs.max_paddings,
help='maximum number of paddings in a batch')
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
action='store_true', action='store_true',
help='disable logging statistics') help='disable logging statistics')
@@ -163,7 +168,7 @@ class EngineArgs:
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
type=str, type=str,
choices=['awq', None], choices=['awq', 'squeezellm', None],
default=None, default=None,
help='Method used to quantize the weights') help='Method used to quantize the weights')
return parser return parser
@@ -193,7 +198,8 @@ class EngineArgs:
self.worker_use_ray) self.worker_use_ray)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs, self.max_num_seqs,
model_config.max_model_len) model_config.max_model_len,
self.max_paddings)
return model_config, cache_config, parallel_config, scheduler_config return model_config, cache_config, parallel_config, scheduler_config

View File

@@ -142,10 +142,10 @@ class RequestTracker:
self._request_streams[request_id].finish() self._request_streams[request_id].finish()
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]: def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
"""Get the new requests and finished requests to be """Get the new requests and finished requests to be
sent to the engine.""" sent to the engine."""
new_requests: List[dict] = [] new_requests: List[Dict] = []
finished_requests: Set[str] = set() finished_requests: Set[str] = set()
while not self._finished_requests.empty(): while not self._finished_requests.empty():
@@ -206,18 +206,17 @@ class _AsyncLLMEngine(LLMEngine):
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers."""
all_outputs = [] coros = []
for worker in self.workers: for worker in self.workers:
if self.parallel_config.worker_use_ray: if self.parallel_config.worker_use_ray:
executor = partial(worker.execute_method.remote, method) coros.append(
worker.execute_method.remote(method, *args, **kwargs))
else: else:
executor = getattr(worker, method) executor = getattr(worker, method)
coros.append(asyncio.get_event_loop().run_in_executor(
None, partial(executor, *args, **kwargs)))
output = executor(*args, **kwargs) all_outputs = await asyncio.gather(*coros)
all_outputs.append(output)
if self.parallel_config.worker_use_ray:
all_outputs = await asyncio.gather(*all_outputs)
if get_all_outputs: if get_all_outputs:
return all_outputs return all_outputs
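
Reviewer note: in the non-Ray branch, `_run_workers_async` now hands each blocking worker call to the default executor and awaits them together, instead of calling them inline on the event loop. A self-contained sketch of that pattern follows. `FakeWorker` and `run_on_all` are hypothetical stand-ins, and the sketch uses `asyncio.get_running_loop()` where the diff calls `asyncio.get_event_loop()`.

```python
import asyncio
from functools import partial
from typing import Any, List


class FakeWorker:
    """Hypothetical stand-in for a worker with a blocking method."""

    def __init__(self, rank: int) -> None:
        self.rank = rank

    def execute(self, x: int) -> int:
        return self.rank * x


async def run_on_all(workers: List[FakeWorker], method: str, *args: Any,
                     **kwargs: Any) -> List[Any]:
    loop = asyncio.get_running_loop()
    futures = [
        # None selects the loop's default ThreadPoolExecutor.
        loop.run_in_executor(None,
                             partial(getattr(worker, method), *args, **kwargs))
        for worker in workers
    ]
    # gather() returns results in the order the awaitables were passed.
    return await asyncio.gather(*futures)


outputs = asyncio.run(
    run_on_all([FakeWorker(rank) for rank in range(3)], "execute", 10))
print(outputs)  # [0, 10, 20]
```

Because the blocking calls run in threads, the event loop stays free to serve other coroutines while the workers execute.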
@@ -484,7 +483,7 @@ class AsyncLLMEngine:
distributed_init_method, placement_group = initialize_cluster( distributed_init_method, placement_group = initialize_cluster(
parallel_config, engine_args.engine_use_ray) parallel_config, engine_args.engine_use_ray)
# Create the async LLM engine. # Create the async LLM engine.
engine = cls(engine_args.worker_use_ray, engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray, engine_args.engine_use_ray,
*engine_configs, *engine_configs,
distributed_init_method, distributed_init_method,

View File

@@ -567,7 +567,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
) )
return self._process_model_outputs(output, scheduler_outputs) + ignored return self._process_model_outputs(output, scheduler_outputs)
def _log_system_stats( def _log_system_stats(
self, self,
@@ -632,8 +632,7 @@ class LLMEngine:
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now self.last_logging_time = now
def _decode_sequence(self, seq: Sequence, def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
sampling_params: SamplingParams) -> None:
"""Decodes the new token for a sequence.""" """Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset, (new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally( read_offset) = detokenize_incrementally(
@@ -642,7 +641,8 @@ class LLMEngine:
prev_tokens=seq.tokens, prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset, prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset, read_offset=seq.read_offset,
skip_special_tokens=sampling_params.skip_special_tokens, skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.spaces_between_special_tokens,
) )
if seq.tokens is None: if seq.tokens is None:
seq.tokens = new_tokens seq.tokens = new_tokens

View File

@@ -17,6 +17,12 @@ app = FastAPI()
engine = None engine = None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/generate") @app.post("/generate")
async def generate(request: Request) -> Response: async def generate(request: Request) -> Response:
"""Generate completion for the request. """Generate completion for the request.

View File

@@ -13,7 +13,7 @@ import uvicorn
from fastapi import Request from fastapi import Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse, Response
from packaging import version from packaging import version
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
@@ -145,6 +145,12 @@ async def check_length(
return input_ids, None return input_ids, None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.get("/v1/models") @app.get("/v1/models")
async def show_available_models(): async def show_available_models():
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
@@ -212,6 +218,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic()) created_time = int(time.monotonic())
try: try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=request.n, n=request.n,
presence_penalty=request.presence_penalty, presence_penalty=request.presence_penalty,
@@ -226,6 +233,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
ignore_eos=request.ignore_eos, ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search, use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens, skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
) )
except ValueError as e: except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -237,6 +245,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
index: int, index: int,
text: str, text: str,
finish_reason: Optional[str] = None, finish_reason: Optional[str] = None,
usage: Optional[UsageInfo] = None,
) -> str: ) -> str:
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=index, index=index,
@@ -249,7 +258,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
model=model_name, model=model_name,
choices=[choice_data], choices=[choice_data],
) )
response_json = response.json(ensure_ascii=False) if usage is not None:
response.usage = usage
# exclude unset to leave details out of each sse
response_json = response.json(exclude_unset=True, ensure_ascii=False)
return response_json return response_json
@@ -275,17 +287,25 @@ async def create_chat_completion(request: ChatCompletionRequest,
i = output.index i = output.index
delta_text = output.text[len(previous_texts[i]):] delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids) completion_tokens = len(output.token_ids)
previous_num_tokens[i] = completion_tokens
response_json = create_stream_response_json( response_json = create_stream_response_json(
index=i, index=i,
text=delta_text, text=delta_text,
) )
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
if output.finish_reason is not None: if output.finish_reason is not None:
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = create_stream_response_json( response_json = create_stream_response_json(
index=i, index=i,
text="", text="",
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
usage=final_usage,
) )
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
@@ -413,6 +433,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
created_time = int(time.monotonic()) created_time = int(time.monotonic())
try: try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=request.n, n=request.n,
best_of=request.best_of, best_of=request.best_of,
@@ -428,6 +449,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
logprobs=request.logprobs, logprobs=request.logprobs,
use_beam_search=request.use_beam_search, use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens, skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
) )
except ValueError as e: except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -452,6 +474,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
text: str, text: str,
logprobs: Optional[LogProbs] = None, logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None, finish_reason: Optional[str] = None,
usage: Optional[UsageInfo] = None,
) -> str: ) -> str:
choice_data = CompletionResponseStreamChoice( choice_data = CompletionResponseStreamChoice(
index=index, index=index,
@@ -465,7 +488,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
model=model_name, model=model_name,
choices=[choice_data], choices=[choice_data],
) )
response_json = response.json(ensure_ascii=False) if usage is not None:
response.usage = usage
response_json = response.json(exclude_unset=True, ensure_ascii=False)
return response_json return response_json
@@ -495,11 +520,19 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if output.finish_reason is not None: if output.finish_reason is not None:
logprobs = (LogProbs() logprobs = (LogProbs()
if request.logprobs is not None else None) if request.logprobs is not None else None)
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = create_stream_response_json( response_json = create_stream_response_json(
index=i, index=i,
text="", text="",
logprobs=logprobs, logprobs=logprobs,
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
usage=final_usage,
) )
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
@@ -615,9 +648,10 @@ if __name__ == "__main__":
max_model_len = engine_model_config.max_model_len max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings. # A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer, tokenizer = get_tokenizer(
tokenizer_mode=engine_args.tokenizer_mode, engine_model_config.tokenizer,
trust_remote_code=engine_args.trust_remote_code) tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
uvicorn.run(app, uvicorn.run(app,
host=args.host, host=args.host,

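The streaming hunks above attach usage to the final chunk only, and serialize with exclude_unset so that intermediate events omit the field. The sketch below shows that pattern with plain dictionaries instead of the server's pydantic models. `make_chunk` and `stream_chunks` are hypothetical helpers written for illustration, not vLLM functions, and the fixed "stop" finish reason is an assumption.

```python
import json
from typing import Dict, Iterator, List, Optional


def make_chunk(index: int,
               text: str,
               finish_reason: Optional[str] = None,
               usage: Optional[Dict[str, int]] = None) -> str:
    """Serialize one stream chunk; "usage" is emitted only when given."""
    payload: Dict[str, object] = {
        "choices": [{
            "index": index,
            "text": text,
            "finish_reason": finish_reason,
        }],
    }
    if usage is not None:
        payload["usage"] = usage
    return json.dumps(payload, ensure_ascii=False)


def stream_chunks(prompt_token_ids: List[int], deltas: List[str],
                  token_counts: List[int]) -> Iterator[str]:
    """Yield SSE events; token_counts[k] is the cumulative completion size."""
    completion_tokens = 0
    for delta, count in zip(deltas, token_counts):
        completion_tokens = count
        yield f"data: {make_chunk(0, delta)}\n\n"
    # Usage is computed once, when the sequence finishes.
    prompt_tokens = len(prompt_token_ids)
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    final = make_chunk(0, "", finish_reason="stop", usage=usage)
    yield f"data: {final}\n\n"
    yield "data: [DONE]\n\n"
```

A client that only needs token accounting can ignore every event except the one carrying a "usage" key.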
View File

@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
use_beam_search: Optional[bool] = False use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
@@ -98,6 +99,7 @@ class CompletionRequest(BaseModel):
use_beam_search: Optional[bool] = False use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
class LogProbs(BaseModel): class LogProbs(BaseModel):
@@ -137,6 +139,7 @@ class CompletionStreamResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[CompletionResponseStreamChoice] choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo]
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
@@ -176,3 +179,5 @@ class ChatCompletionStreamResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[ChatCompletionResponseStreamChoice] choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(
default=None, description="data about request and response")

View File

@@ -48,4 +48,9 @@ _setup_logger()
def init_logger(name: str): def init_logger(name: str):
return logging.getLogger(name) # Use the same settings as above for root logger
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.addHandler(_default_handler)
logger.propagate = False
return logger

View File

@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
from xformers.ops import AttentionBias from xformers.ops import AttentionBias
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
@@ -29,6 +29,8 @@ class InputMetadata:
context_lens: torch.Tensor, context_lens: torch.Tensor,
max_context_len: int, max_context_len: int,
block_tables: torch.Tensor, block_tables: torch.Tensor,
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
) -> None: ) -> None:
self.seq_groups = seq_groups self.seq_groups = seq_groups
@@ -38,12 +40,15 @@ class InputMetadata:
self.context_lens = context_lens self.context_lens = context_lens
self.max_context_len = max_context_len self.max_context_len = max_context_len
self.block_tables = block_tables self.block_tables = block_tables
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.to_cache = None self.to_cache = None
if sliding_window is not None: if sliding_window is not None:
# We need to keep the positions of sliding windows within # We need to keep the positions of sliding windows within
# the key / value tables, this is helpful to know which # the key / value tables, this is helpful to know which
# elements we need to cache and where # elements we need to cache.
to_cache, start_idx = [], 0 to_cache, start_idx = [], 0
for prompt_len in self.prompt_lens: for prompt_len in self.prompt_lens:
to_cache.extend( to_cache.extend(
@@ -51,36 +56,36 @@ class InputMetadata:
start_idx + max(0, prompt_len - sliding_window), start_idx + max(0, prompt_len - sliding_window),
start_idx + prompt_len, start_idx + prompt_len,
)) ))
start_idx += prompt_len start_idx += self.max_prompt_len
to_cache.extend(range(start_idx, slot_mapping.shape[0])) to_cache.extend(range(start_idx, slot_mapping.shape[0]))
self.to_cache = torch.tensor(to_cache, self.to_cache = torch.tensor(to_cache,
dtype=torch.int32, dtype=torch.int32,
device=self.slot_mapping.device) device=self.slot_mapping.device)
self.num_prompts = len(prompt_lens) self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens) self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
self.num_generation_tokens = context_lens.shape[0] self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0: if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1] self.max_num_blocks_per_seq = block_tables.shape[1]
else: else:
self.max_num_blocks_per_seq = 0 self.max_num_blocks_per_seq = 0
assert block_tables.shape[0] == self.num_generation_tokens assert block_tables.shape[0] == self.num_generation_tokens
assert context_lens.shape[0] == self.num_generation_tokens
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
self.attn_bias: List[AttentionBias] = [] self.attn_bias: Optional[AttentionBias] = None
def __repr__(self) -> str: def __repr__(self) -> str:
# Print only useful metadata. # Print only useful metadata.
return (f'InputMetadata(' return (
f'num_valid_tokens={self.num_valid_tokens}, ' f'InputMetadata('
f'num_prompt_tokens={self.num_prompt_tokens}, ' f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, ' f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, ' f'prompt_lens={self.prompt_lens}, '
f'num_generation_tokens={self.num_generation_tokens}, ' f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, ' f'context_lens={self.context_lens}, '
f'max_context_len={self.max_context_len}), ' f'max_context_len={self.max_context_len}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}), ' f'block_tables={self.block_tables}, '
f'slot_mapping={self.slot_mapping}') f'selected_token_indices={self.selected_token_indices}, '
f'categorized_sample_indices={self.categorized_sample_indices}, '
f'slot_mapping={self.slot_mapping})')

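With padded prompts, each prompt occupies max_prompt_len positions, so the sliding-window bookkeeping above advances start_idx by the padded length rather than the real one. The index computation can be sketched in plain Python as follows. `sliding_window_cache_indices` is a hypothetical name, and `num_slots` is a stand-in for the tensor dimension the source reads from slot_mapping.

```python
from typing import List


def sliding_window_cache_indices(prompt_lens: List[int],
                                 sliding_window: int,
                                 num_slots: int) -> List[int]:
    """Return the flattened positions whose keys/values should be cached."""
    max_prompt_len = max(prompt_lens) if prompt_lens else 0
    to_cache: List[int] = []
    start_idx = 0
    for prompt_len in prompt_lens:
        # Keep only the last `sliding_window` real tokens of this prompt.
        to_cache.extend(
            range(start_idx + max(0, prompt_len - sliding_window),
                  start_idx + prompt_len))
        # Each prompt is padded to max_prompt_len in the flattened layout.
        start_idx += max_prompt_len
    # Remaining positions (for example generation tokens) are all cached.
    to_cache.extend(range(start_idx, num_slots))
    return to_cache
```

For prompts of length 5 and 3 with a window of 2, the kept positions are 3 and 4 from the first prompt and 6 and 7 from the second, because the second prompt starts at offset 5.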
View File

@@ -1,24 +1,27 @@
"""Custom activation functions.""" """Custom activation functions."""
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm import activation_ops from vllm import activation_ops
from vllm.model_executor.layers.quantization import QuantizationConfig
class SiluAndMul(nn.Module): class SiluAndMul(nn.Module):
"""An activation function for SwiGLU. """An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2. The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes: Shapes:
x: (num_tokens, 2 * d) x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (num_tokens, d) return: (batch_size, seq_len, d) or (num_tokens, d)
""" """
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0] d = x.shape[-1] // 2
d = x.shape[1] // 2 output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x) activation_ops.silu_and_mul(out, x)
return out return out
@@ -26,9 +29,7 @@ class SiluAndMul(nn.Module):
class NewGELU(nn.Module): class NewGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0] out = torch.empty_like(x)
d = x.shape[1]
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.gelu_new(out, x) activation_ops.gelu_new(out, x)
return out return out
@@ -36,13 +37,32 @@ class NewGELU(nn.Module):
class FastGELU(nn.Module): class FastGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0] out = torch.empty_like(x)
d = x.shape[1]
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.gelu_fast(out, x) activation_ops.gelu_fast(out, x)
return out return out
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
This is used for some quantization methods like AWQ.
"""
def __init__(
self,
act_module: nn.Module,
hidden_size: int,
params_dtype: torch.dtype,
):
super().__init__()
self.act = act_module
self.scales = nn.Parameter(
torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
def forward(self, x: torch.Tensor):
return self.act(x) / self.scales
_ACTIVATION_REGISTRY = { _ACTIVATION_REGISTRY = {
"gelu": nn.GELU(), "gelu": nn.GELU(),
"gelu_fast": FastGELU(), "gelu_fast": FastGELU(),
@@ -52,9 +72,27 @@ _ACTIVATION_REGISTRY = {
} }
def get_act_fn(act_fn: str) -> nn.Module: def get_act_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
) -> nn.Module:
"""Get an activation function by name.""" """Get an activation function by name."""
act_fn = act_fn.lower() act_fn_name = act_fn_name.lower()
if act_fn in _ACTIVATION_REGISTRY: if act_fn_name not in _ACTIVATION_REGISTRY:
return _ACTIVATION_REGISTRY[act_fn] raise ValueError(
raise ValueError(f"Activation function {act_fn!r} is not supported.") f"Activation function {act_fn_name!r} is not supported.")
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
if quant_config is not None:
if act_fn_name in quant_config.get_scaled_act_names():
if intermediate_size is None:
raise ValueError(
"intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(
act_fn,
intermediate_size,
params_dtype=torch.get_default_dtype(),
)
return act_fn

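The SiluAndMul docstring above defines the operation as silu(x[:d]) * x[d:] over the last dimension. A pure-Python reference for a single vector makes the split explicit. This sketch only mirrors the formula and is not the CUDA kernel behind activation_ops.silu_and_mul.

```python
import math
from typing import List


def silu(v: float) -> float:
    """SiLU (swish): v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def silu_and_mul(x: List[float]) -> List[float]:
    """Gate the second half of x with SiLU of the first half."""
    d = len(x) // 2
    return [silu(a) * b for a, b in zip(x[:d], x[d:2 * d])]
```

The output has half the width of the input, which is why the new forward pass allocates a tensor of shape x.shape[:-1] + (d, ).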
View File

@@ -10,9 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm import attention_ops from vllm import attention_ops
from vllm import cache_ops from vllm import cache_ops
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding import get_rope
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
RotaryEmbedding)
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -23,25 +21,9 @@ class PagedAttention(nn.Module):
# pylint: disable=line-too-long # pylint: disable=line-too-long
"""GPT-style multi-head PagedAttention. """GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The This class takes query, key, and value tensors as input. The input tensors
input 1D tensors can either contain prompt tokens or generation tokens, in can either contain prompt tokens or generation tokens, in addition to
addition to paddings. paddings.
If the input tensors contain prompt tokens, the layout is as follows:
|<---------------------- num_valid_tokens ---------------------->|
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
Otherwise, the layout is as follows:
|<------------------ num_valid_tokens ------------------->|
|<------- num_generation_tokens (M) ------->|
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
The prompts might have different lengths, while the generation tokens always
have length 1. The paddings are appended to make the input length a multiple
of 8, which is desirable for Tensor Cores.
The class does the following: The class does the following:
1. Perform multi_query_kv_attention for the prompts. This operation does 1. Perform multi_query_kv_attention for the prompts. This operation does
@@ -53,7 +35,7 @@ class PagedAttention(nn.Module):
4. Perform single_query_cached_kv_attention for the generation tokens. 4. Perform single_query_cached_kv_attention for the generation tokens.
This operation reads the previous key and value tensors from the KV This operation reads the previous key and value tensors from the KV
cache. cache.
5. Output a flattened 1D tensor. 5. Return the output tensor.
""" """
def __init__(self, def __init__(self,
@@ -85,14 +67,15 @@ class PagedAttention(nn.Module):
dtype: torch.dtype, dtype: torch.dtype,
) -> None: ) -> None:
del dtype # Unused. del dtype # Unused.
if input_metadata.attn_bias: if input_metadata.attn_bias is not None:
# Already set by a previous layer. # Already set by a previous layer.
return return
prompt_lens = input_metadata.prompt_lens prompt_lens = [input_metadata.max_prompt_len
] * input_metadata.num_prompts
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens) attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
if self.sliding_window is not None: if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(self.sliding_window) attn_bias = attn_bias.make_local_attention(self.sliding_window)
input_metadata.attn_bias.append(attn_bias) input_metadata.attn_bias = attn_bias
def multi_query_kv_attention( def multi_query_kv_attention(
self, self,
@@ -111,7 +94,6 @@ class PagedAttention(nn.Module):
value: shape = [num_prompt_tokens, num_kv_heads, head_size] value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention. input_metadata: metadata for paged attention.
""" """
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# Project the key and value tensors to the desired number of heads. # Project the key and value tensors to the desired number of heads.
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1) key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
@@ -124,7 +106,7 @@ class PagedAttention(nn.Module):
query.unsqueeze(0), query.unsqueeze(0),
key.unsqueeze(0), key.unsqueeze(0),
value.unsqueeze(0), value.unsqueeze(0),
attn_bias=input_metadata.attn_bias[0], attn_bias=input_metadata.attn_bias,
p=0.0, p=0.0,
scale=self.scale, scale=self.scale,
) )
@@ -172,7 +154,9 @@ class PagedAttention(nn.Module):
# sequences or heads is large, we use V1 since there is enough work # sequences or heads is large, we use V1 since there is enough work
# to parallelize. # to parallelize.
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 # For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = input_metadata.max_context_len <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
attention_ops.paged_attention_v1( attention_ops.paged_attention_v1(
@@ -232,12 +216,12 @@ class PagedAttention(nn.Module):
"""PagedAttention forward pass. """PagedAttention forward pass.
NOTE: The query, key, and value tensors must be sliced from a qkv NOTE: The query, key, and value tensors must be sliced from a qkv
tensor of shape [num_tokens, 3 * num_heads * head_size]. tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].
Args: Args:
query: shape = [num_tokens, num_heads * head_size] query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x, key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x] block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size, value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -246,9 +230,9 @@ class PagedAttention(nn.Module):
cache_event: event to wait for the cache operations to finish. cache_event: event to wait for the cache operations to finish.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
batch_size, seq_len, _ = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -264,10 +248,10 @@ class PagedAttention(nn.Module):
assert input_metadata.num_generation_tokens == 0 assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata, dtype=query.dtype) self.set_attn_bias(input_metadata, dtype=query.dtype)
self.multi_query_kv_attention( self.multi_query_kv_attention(
output[:num_prompt_tokens], output,
query[:num_prompt_tokens], query,
key[:num_prompt_tokens], key,
value[:num_prompt_tokens], value,
input_metadata, input_metadata,
) )
@@ -278,13 +262,10 @@ class PagedAttention(nn.Module):
# Reshape the keys and values and store them in the cache. # Reshape the keys and values and store them in the cache.
# When key_cache and value_cache are not provided, the new key # When key_cache and value_cache are not provided, the new key
# and value vectors will not be cached. # and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens if key_cache is not None and value_cache is not None:
if (num_valid_tokens > 0 and key_cache is not None key_to_cache = key
and value_cache is not None): value_to_cache = value
# The stride is 3 because the key and value are sliced from qkv. slot_mapping = input_metadata.slot_mapping.view(-1)
key_to_cache = key[:num_valid_tokens]
value_to_cache = value[:num_valid_tokens]
slot_mapping = input_metadata.slot_mapping
if input_metadata.to_cache is not None: if input_metadata.to_cache is not None:
key_to_cache = key_to_cache[input_metadata.to_cache] key_to_cache = key_to_cache[input_metadata.to_cache]
value_to_cache = value_to_cache[input_metadata.to_cache] value_to_cache = value_to_cache[input_metadata.to_cache]
@@ -305,14 +286,14 @@ class PagedAttention(nn.Module):
"key_cache and value_cache must be provided when " "key_cache and value_cache must be provided when "
"generating tokens.") "generating tokens.")
# Compute the attention op for generation tokens. # Compute the attention op for generation tokens.
self.single_query_cached_kv_attention( self.single_query_cached_kv_attention(output, query, key_cache,
output[num_prompt_tokens:num_valid_tokens], value_cache, input_metadata,
query[num_prompt_tokens:num_valid_tokens], key_cache, self.get_alibi_slopes())
value_cache, input_metadata, self.get_alibi_slopes())
# Reshape the output tensor. # Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings. # NOTE(woosuk): The output tensor may include paddings.
return output.view(-1, self.num_heads * self.head_size) return output.view(batch_size, seq_len,
self.num_heads * self.head_size)
class PagedAttentionWithRoPE(PagedAttention): class PagedAttentionWithRoPE(PagedAttention):
@@ -336,23 +317,8 @@ class PagedAttentionWithRoPE(PagedAttention):
scale, scale,
num_kv_heads, num_kv_heads,
sliding_window=sliding_window) sliding_window=sliding_window)
if rope_scaling is None: self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
self.rotary_emb = RotaryEmbedding(head_size, rotary_dim, is_neox_style, rope_scaling)
max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LinearScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
elif scaling_type == "dynamic":
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward( def forward(
self, self,
@@ -368,10 +334,10 @@ class PagedAttentionWithRoPE(PagedAttention):
""" PagedAttention forward pass with rotary embedding. """ PagedAttention forward pass with rotary embedding.
Args: Args:
positions: shape = [num_tokens] positions: shape = [batch_size, seq_len]
query: shape = [num_tokens, num_heads * head_size] query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x, key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x] block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size, value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -380,7 +346,7 @@ class PagedAttentionWithRoPE(PagedAttention):
cache_event: event to wait for the cache operations to finish. cache_event: event to wait for the cache operations to finish.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
# Apply rotary embedding to the query and key before passing them # Apply rotary embedding to the query and key before passing them
@@ -414,34 +380,34 @@ class PagedAttentionWithALiBi(PagedAttention):
def set_attn_bias(self, input_metadata: InputMetadata, def set_attn_bias(self, input_metadata: InputMetadata,
dtype: torch.dtype) -> None: dtype: torch.dtype) -> None:
if input_metadata.attn_bias: if input_metadata.attn_bias is not None:
# Already set by a previous layer. # Already set by a previous layer.
return return
# Generates ALiBi mask for each prompt. # Generates ALiBi mask based on the max prompt length.
for prompt_len in input_metadata.prompt_lens: max_prompt_len = input_metadata.max_prompt_len
bias = torch.arange(prompt_len, dtype=dtype) bias = torch.arange(max_prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)` # `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but # here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi # the bias below more accurately follows the original ALiBi
# paper. # paper.
bias = bias[None, :] - bias[:, None] bias = bias[None, :] - bias[:, None]
bias = bias.to(self.alibi_slopes.device) bias = bias.to(self.alibi_slopes.device)
# When using custom attention bias, xformers requires the bias to # When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8. # be sliced from a tensor whose length is a multiple of 8.
padded_len = (prompt_len + 7) // 8 * 8 padded_len = (max_prompt_len + 7) // 8 * 8
bias = torch.empty( bias = torch.empty(
1, # batch_size input_metadata.num_prompts,
self.num_heads, self.num_heads,
prompt_len, max_prompt_len,
padded_len, padded_len,
device=self.alibi_slopes.device, device=self.alibi_slopes.device,
dtype=dtype, dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias) )[:, :, :, :max_prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None]) bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias) attn_bias = LowerTriangularMaskWithTensorBias(bias)
input_metadata.attn_bias.append(attn_bias) input_metadata.attn_bias = attn_bias
def multi_query_kv_attention( def multi_query_kv_attention(
self, self,
@@ -466,24 +432,19 @@ class PagedAttentionWithALiBi(PagedAttention):
value = torch.repeat_interleave(value, value = torch.repeat_interleave(value,
self.num_queries_per_kv, self.num_queries_per_kv,
dim=1) dim=1)
batch_size = input_metadata.num_prompts
seq_len = input_metadata.max_prompt_len
# FIXME(woosuk): Because xformers does not support dynamic sequence out = xops.memory_efficient_attention_forward(
# lengths with custom attention bias, we process each prompt one by query.view(batch_size, seq_len, self.num_heads, self.head_size),
# one. This is inefficient, especially when we have many short prompts. key.view(batch_size, seq_len, self.num_heads, self.head_size),
start = 0 value.view(batch_size, seq_len, self.num_heads, self.head_size),
for i, prompt_len in enumerate(input_metadata.prompt_lens): attn_bias=input_metadata.attn_bias,
end = start + prompt_len p=0.0,
out = xops.memory_efficient_attention_forward( scale=self.scale,
query[None, start:end], )
key[None, start:end], # TODO(woosuk): Unnecessary copy. Optimize.
value[None, start:end], output.copy_(out.view(-1, self.num_heads, self.head_size))
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output return output
def get_alibi_slopes(self) -> Optional[torch.Tensor]: def get_alibi_slopes(self) -> Optional[torch.Tensor]:

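The ALiBi hunk above builds a single bias for the padded prompt length: entry (i, j) is j - i, scaled per head by a slope, and the backing buffer is padded to a multiple of 8. The arithmetic can be sketched without torch or xformers as follows. `padded_length` and `alibi_bias` are hypothetical names, and the causal masking that the source delegates to LowerTriangularMaskWithTensorBias is not reproduced here.

```python
from typing import List


def padded_length(seq_len: int) -> int:
    """Round seq_len up to the next multiple of 8."""
    return (seq_len + 7) // 8 * 8


def alibi_bias(seq_len: int, slopes: List[float]) -> List[List[List[float]]]:
    """Return bias[head][i][j] = slope[head] * (j - i)."""
    return [[[slope * (j - i) for j in range(seq_len)]
             for i in range(seq_len)]
            for slope in slopes]
```

Because every prompt now shares the same padded length, one bias can serve the whole batch instead of one bias per prompt.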
View File

@@ -1,4 +1,6 @@
"""Custom normalization layers.""" """Custom normalization layers."""
from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -21,7 +23,19 @@ class RMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
layernorm_ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x) out = torch.empty_like(x)
layernorm_ops.rms_norm( layernorm_ops.rms_norm(
out, out,

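The new RMSNorm path above fuses the residual addition with the normalization and returns both tensors. The sketch below is a pure-Python reference of the assumed semantics: the residual becomes x + residual, and the output is the RMS-normalized sum scaled by the weight. The kernel behind layernorm_ops.fused_add_rms_norm is not shown in this section, so treat this as an illustration of the intended math rather than a description of the kernel.

```python
import math
from typing import List, Tuple


def rms_norm(x: List[float], weight: List[float], eps: float) -> List[float]:
    """Scale x by the reciprocal of its root mean square, then by weight."""
    variance = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(variance + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def fused_add_rms_norm(x: List[float], residual: List[float],
                       weight: List[float],
                       eps: float) -> Tuple[List[float], List[float]]:
    """Return (normalized output, updated residual)."""
    summed = [a + b for a, b in zip(x, residual)]
    return rms_norm(summed, weight, eps), summed
```

Returning the updated residual lets the next layer reuse it without a separate elementwise add.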
View File

@@ -0,0 +1,541 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from vllm.model_executor.parallel_utils.utils import (
divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
logger = init_logger(__name__)
class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
def create_weights(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
"""Create weights for a linear layer."""
raise NotImplementedError
@abstractmethod
def apply_weights(self,
weights: Dict[str, torch.Tensor],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights to the input tensor."""
raise NotImplementedError
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization.
Args:
separate_bias_add: If true, add bias separately after matrix
multiplication.
"""
def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add
def create_weights(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
weight = Parameter(torch.empty(output_size,
input_size,
device=torch.cuda.current_device(),
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
return {"weight": weight}
def apply_weights(self,
weights: Dict[str, torch.Tensor],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = weights["weight"]
if self.separate_bias_add:
if bias is not None:
return F.linear(x, weight) + bias
return F.linear(x, weight)
return F.linear(x, weight, bias)
class ReplicatedLinear(torch.nn.Module):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size, self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
self.register_parameter(name, weight)
if bias:
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=self.params_dtype))
set_weight_attrs(self.bias, {"output_dim": 0})
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
output = self.linear_method.apply_weights(self.linear_weights, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size)
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size, self.output_size_per_partition, self.params_dtype)
for name, weight in self.linear_weights.items():
self.register_parameter(name, weight)
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
param_data = param.data
if output_dim is not None:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
self.linear_weights, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
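
The column-parallel scheme from the `ColumnParallelLinear` docstring can be checked numerically. The sketch below is a pure-Python illustration and not vLLM code: `matmul`, `split_columns`, and `column_parallel_forward` are hypothetical helpers, the ranks are simulated in a loop, and the all-gather is modelled as concatenation along the last dimension. Note that the layer above stores the weight as `[output_size, input_size]`, so splitting the columns of A corresponds to narrowing the stored weight along `output_dim = 0`.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, a: Matrix) -> Matrix:
    """Plain dense matrix product x @ a."""
    return [[sum(x_row[k] * a[k][j] for k in range(len(a)))
             for j in range(len(a[0]))] for x_row in x]


def split_columns(a: Matrix, num_partitions: int) -> List[Matrix]:
    """Split A = [A_1, ..., A_p] along its second (column) dimension."""
    cols = len(a[0])
    assert cols % num_partitions == 0, "output size must divide evenly"
    step = cols // num_partitions
    return [[row[i * step:(i + 1) * step] for row in a]
            for i in range(num_partitions)]


def column_parallel_forward(x: Matrix, a: Matrix, tp_size: int) -> Matrix:
    """Each simulated rank computes Y_i = X A_i; concatenation stands in
    for the all-gather."""
    partial_outputs = [matmul(x, a_i) for a_i in split_columns(a, tp_size)]
    return [sum((y_i[r] for y_i in partial_outputs), [])
            for r in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]
a = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]
# The sharded computation reproduces the unsharded product.
assert column_parallel_forward(x, a, tp_size=2) == matmul(x, a)
```

With `gather_output=False`, the layer would stop after the per-rank products and hand each `Y_i` to the next (typically row-parallel) layer.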
class MergedColumnParallelLinear(ColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make the output
available to all GPUs, otherwise, every GPU will have
its own output.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
"""
def __init__(
self,
input_size: int,
output_sizes: List[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
current_shard_offset = 0
shard_offsets = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
else:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
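
The offset arithmetic in `MergedColumnParallelLinear.weight_loader` is easier to follow with concrete numbers. The helper below is a hypothetical, standalone restatement of that arithmetic (it is not part of the diff). It assumes the parameter carries an `output_dim`, and it uses `pack_factor=1` to stand for the case where the output dimension is not packed.

```python
from typing import List, Tuple


def merged_shard_slices(output_sizes: List[int], tp_size: int, tp_rank: int,
                        shard_id: int,
                        pack_factor: int = 1) -> Tuple[int, int, int]:
    """Return (param_offset, shard_size, loaded_start) along the output dim.

    param_offset and shard_size locate the destination slice inside this
    rank's merged parameter; loaded_start is where this rank's slice begins
    inside the checkpoint tensor of the selected sub-layer.
    """
    assert all(size % tp_size == 0 for size in output_sizes)
    assert 0 <= shard_id < len(output_sizes)
    param_offset = sum(output_sizes[:shard_id]) // tp_size
    shard_size = output_sizes[shard_id] // tp_size
    # A tensor packed along the output dim stores `pack_factor` logical
    # elements per stored element, so offset and size both shrink.
    param_offset //= pack_factor
    shard_size //= pack_factor
    loaded_start = tp_rank * shard_size
    return param_offset, shard_size, loaded_start


# Two merged projections of width 11008 each, 2 ranks, second sub-layer.
print(merged_shard_slices([11008, 11008], tp_size=2, tp_rank=1, shard_id=1))
# Same layout with 8 four-bit values packed into each stored int32.
print(merged_shard_slices([11008, 11008], tp_size=2, tp_rank=1, shard_id=1,
                          pack_factor=8))
```

The sizes 11008 are illustrative values chosen for the example, not something the diff prescribes.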
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
Linear layers for the linear transformation of the query, key, and value
vectors in the attention layer. The weight matrix is concatenated along
the output dimension. The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number of query
heads (e.g., multi-query/grouped-query attention), the key/value head may
be replicated while the query heads are partitioned.
Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size,
self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, linear_method)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
]
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"]
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
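# Query heads are never replicated, so each rank reads its own slice.
# Only key/value heads may be shared by several ranks.
if loaded_shard_id == "q":
shard_id = tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas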
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
else:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
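
The head bookkeeping in `QKVParallelLinear.__init__` can be reproduced without any GPU state. The function below is a hypothetical sketch that mirrors the branch on `tp_size >= total_num_kv_heads`; the name `qkv_partition` and the returned dictionary keys are introduced here for illustration only.

```python
from typing import Dict


def qkv_partition(total_num_heads: int, total_num_kv_heads: int,
                  tp_size: int, tp_rank: int, head_size: int) -> Dict[str, int]:
    """Per-rank head counts for a fused QKV projection."""
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if tp_size >= total_num_kv_heads:
        # More ranks than KV heads: every rank holds one KV head, and each
        # KV head is shared by `num_kv_head_replicas` consecutive ranks.
        assert tp_size % total_num_kv_heads == 0
        num_kv_heads = 1
        num_kv_head_replicas = tp_size // total_num_kv_heads
    else:
        assert total_num_kv_heads % tp_size == 0
        num_kv_heads = total_num_kv_heads // tp_size
        num_kv_head_replicas = 1
    return {
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "num_kv_head_replicas": num_kv_head_replicas,
        # Index of the K/V block this rank reads from the checkpoint.
        "kv_shard_id": tp_rank // num_kv_head_replicas,
        # Width of this rank's slice of the fused QKV output.
        "output_size_per_rank": (num_heads + 2 * num_kv_heads) * head_size,
    }


# Grouped-query attention: 32 query heads, 4 KV heads, 8 ranks.
info = qkv_partition(32, 4, tp_size=8, tp_rank=5, head_size=128)
print(info)
```

In this configuration ranks 4 and 5 both read KV head 2, while each rank still owns a distinct group of four query heads.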
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
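params_dtype: Data type for the parameters.
reduce_results: If true, all-reduce the partial outputs across the
tensor-parallel ranks so that every rank holds the full Y.
linear_method: (Maybe quantized) linear method.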
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size_per_partition, self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
self.register_parameter(name, weight)
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When the results are not reduced, adding bias to "
"them can lead to incorrect results.")
if bias:
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None)
param_data = param.data
if input_dim is not None:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def forward(self, input_):
# Split the input along its last dimension unless it is already split.
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
self.linear_weights, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias
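
The row-parallel forward pass can be simulated the same way. The sketch below is a hypothetical pure-Python model, not vLLM code: the per-rank loop stands in for separate GPUs and the element-wise sum stands in for `tensor_model_parallel_all_reduce`. It also shows why the bias is added only after the reduction: adding it on every rank first would count it `tp_size` times, which is the situation the `ValueError` in `__init__` guards against.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, a: Matrix) -> Matrix:
    """Plain dense matrix product x @ a."""
    return [[sum(x_row[k] * a[k][j] for k in range(len(a)))
             for j in range(len(a[0]))] for x_row in x]


def row_parallel_forward(x: Matrix, a: Matrix, bias: List[float],
                         tp_size: int) -> Matrix:
    in_features = len(a)
    assert in_features % tp_size == 0, "input size must divide evenly"
    step = in_features // tp_size
    partials = []
    for rank in range(tp_size):
        # X is split along its last dimension, A along its first.
        x_i = [row[rank * step:(rank + 1) * step] for row in x]
        a_i = a[rank * step:(rank + 1) * step]
        partials.append(matmul(x_i, a_i))
    # Simulated all-reduce: sum the partial products over the ranks.
    reduced = [[sum(p[r][c] for p in partials) for c in range(len(a[0]))]
               for r in range(len(x))]
    # The bias is not parallelized, so it is added once, after the reduce.
    return [[value + b for value, b in zip(row, bias)] for row in reduced]


x = [[1.0, 2.0, 3.0, 4.0]]
a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
bias = [10.0, 20.0]
full = [[value + b for value, b in zip(row, bias)] for row in matmul(x, a)]
assert row_parallel_forward(x, a, bias, tp_size=2) == full
```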

@@ -0,0 +1,22 @@
from typing import Type
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
_QUANTIZATION_CONFIG_REGISTRY = {
"awq": AWQConfig,
"squeezellm": SqueezeLLMConfig,
}
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_CONFIG_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quantization_config",
]
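
As a usage illustration, the registry returns a config class rather than an instance, and the caller then builds the instance with `from_config`. The sketch below reproduces that pattern with stand-in classes so that it runs without vLLM installed: `BaseConfig`, `ToyAWQConfig`, `_REGISTRY`, and `get_config_cls` are hypothetical names, and the dictionary keys are example values.

```python
from typing import Any, Dict, Type


class BaseConfig:
    """Hypothetical stand-in for QuantizationConfig."""

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BaseConfig":
        raise NotImplementedError


class ToyAWQConfig(BaseConfig):
    """Hypothetical stand-in for a concrete config class."""

    def __init__(self, weight_bits: int, group_size: int) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyAWQConfig":
        return cls(config["w_bit"], config["q_group_size"])


_REGISTRY: Dict[str, Type[BaseConfig]] = {"awq": ToyAWQConfig}


def get_config_cls(name: str) -> Type[BaseConfig]:
    """Look up the config class registered under `name`."""
    if name not in _REGISTRY:
        raise ValueError(f"Invalid quantization method: {name}")
    return _REGISTRY[name]


config_cls = get_config_cls("awq")  # a class, not an instance
quant_config = config_cls.from_config({"w_bit": 4, "q_group_size": 128})
print(type(quant_config).__name__, quant_config.weight_bits)
```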

@@ -0,0 +1,158 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
def get_name(self) -> str:
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
@staticmethod
def get_config_filenames() -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
def get_linear_method(self) -> "AWQLinearMethod":
return AWQLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.
Args:
quant_config: The AWQ quantization config.
"""
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
def create_weights(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
if input_size % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if output_size % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
qweight = Parameter(
torch.empty(
input_size,
output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
qzeros = Parameter(
torch.empty(
input_size // self.quant_config.group_size,
output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
scales = Parameter(
torch.empty(
input_size // self.quant_config.group_size,
output_size,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": 0,
"output_dim": 1,
})
return {
"qweight": qweight,
"qzeros": qzeros,
"scales": scales,
}
def apply_weights(self,
weights: Dict[str, torch.Tensor],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
qzeros = weights["qzeros"]
scales = weights["scales"]
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
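
The tensor shapes allocated by `AWQLinearMethod.create_weights` follow directly from `pack_factor = 32 // weight_bits` and `group_size`. The function below is a hypothetical shape calculator that mirrors those allocations and the two alignment checks; the layer sizes in the example are illustrative and not taken from the diff.

```python
from typing import Dict, Tuple


def awq_weight_shapes(input_size: int, output_size: int,
                      weight_bits: int = 4,
                      group_size: int = 128) -> Dict[str, Tuple[int, int]]:
    """Shapes of the per-rank AWQ tensors for one linear layer."""
    pack_factor = 32 // weight_bits  # 4-bit values per stored int32
    if input_size % group_size != 0:
        raise ValueError("input size is not aligned with the group size")
    if output_size % pack_factor != 0:
        raise ValueError("output size is not aligned with the pack factor")
    return {
        "qweight": (input_size, output_size // pack_factor),
        "qzeros": (input_size // group_size, output_size // pack_factor),
        "scales": (input_size // group_size, output_size),
    }


print(awq_weight_shapes(4096, 11008))

# A row-parallel layer divides the input size by the tensor parallel size.
# 11008 / 2 = 5504 is a multiple of 128, but 11008 / 4 = 2752 is not, which
# is the "too large tensor parallel size" case reported by create_weights.
print(awq_weight_shapes(11008 // 2, 4096))
try:
    awq_weight_shapes(11008 // 4, 4096)
except ValueError as err:
    print("rejected:", err)
```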

@@ -1,22 +1,26 @@
+from abc import ABC, abstractmethod
 from typing import Any, Dict, List

 import torch

+from vllm.model_executor.layers.linear import LinearMethodBase
+

-class QuantizationConfig:
+class QuantizationConfig(ABC):
+    """Base class for quantization configs."""

-    @classmethod
-    def get_name(cls) -> str:
+    @abstractmethod
+    def get_name(self) -> str:
         """Name of the quantization method."""
         raise NotImplementedError

-    @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    @abstractmethod
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
         """List of supported activation dtypes."""
         raise NotImplementedError

-    @classmethod
-    def get_min_capability(cls) -> int:
+    @abstractmethod
+    def get_min_capability(self) -> int:
         """Minimum GPU capability to support the quantization method.

         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
@@ -25,12 +29,14 @@ class QuantizationConfig:
         """
         raise NotImplementedError

-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    @staticmethod
+    @abstractmethod
+    def get_config_filenames() -> List[str]:
         """List of filenames to search for in the model directory."""
         raise NotImplementedError

     @classmethod
+    @abstractmethod
     def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
         """Create a config class from the model's quantization config."""
         raise NotImplementedError
@@ -44,32 +50,15 @@ class QuantizationConfig:
         raise ValueError(f"Cannot find any of {keys} in the model's "
                          "quantization config.")

-    @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
+    @abstractmethod
+    def get_linear_method(self) -> LinearMethodBase:
+        """Get the linear method to use for the quantized linear layer."""
         raise NotImplementedError

-    @classmethod
-    def is_packed(cls, tensor_name: str) -> bool:
-        """Returns True if a tensor is packed.
+    @abstractmethod
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.

-        A tensor is considered packed if each element in the tensor is a
-        packed representation of multiple elements in the original tensor.
-        For example, an INT32 element in the tensor may represent 8 INT4
-        elements in the original tensor.
+        For now, this is only used by AWQ.
         """
-        return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
-
-    @classmethod
-    def get_transposed_tensor_names(cls) -> List[str]:
-        raise NotImplementedError
-
-    @classmethod
-    def is_transposed(cls, tensor_name: str) -> bool:
-        """Returns True if a tensor is transposed relative to nn.Linear.weight.
-        """
-        return any(tag in tensor_name
-                   for tag in cls.get_transposed_tensor_names())
-
-    @classmethod
-    def get_tp_tensor_names(cls) -> List[str]:
         raise NotImplementedError
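
The config classes call `cls.get_from_keys(config, [...])` with several alternative key names, but the diff only shows the error path of that helper. The sketch below is one plausible reading, written as a free function for illustration: it assumes the helper returns the value of the first listed key that is present, which the hunk does not confirm.

```python
from typing import Any, Dict, List


def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
    """Return the value of the first key in `keys` found in `config`.

    Assumed behaviour; only the ValueError below appears in the diff.
    """
    for key in keys:
        if key in config:
            return config[key]
    raise ValueError(f"Cannot find any of {keys} in the model's "
                     "quantization config.")


# Two checkpoints that spell the same setting differently.
config_a = {"w_bit": 4, "q_group_size": 128, "zero_point": True}
config_b = {"bits": 4, "group_size": 128, "zero_point": True}
for config in (config_a, config_b):
    weight_bits = get_from_keys(config, ["w_bit", "bits"])
    group_size = get_from_keys(config, ["q_group_size", "group_size"])
    print(weight_bits, group_size)
```

Accepting a list of aliases lets one `from_config` implementation read config files written by different quantization tools.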

@@ -0,0 +1,124 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class SqueezeLLMConfig(QuantizationConfig):
"""Config class for SqueezeLLM.
Reference: https://arxiv.org/pdf/2306.07629
"""
def __init__(
self,
weight_bits: int,
) -> None:
self.weight_bits = weight_bits
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"SqueezeLLM, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
def get_name(self) -> str:
return "squeezellm"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
return 70
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
def get_linear_method(self) -> "SqueezeLLMLinearMethod":
return SqueezeLLMLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
class SqueezeLLMLinearMethod(LinearMethodBase):
"""Linear method for SqueezeLLM.
Args:
quant_config: The SqueezeLLM quantization config.
"""
def __init__(self, quant_config: SqueezeLLMConfig):
self.quant_config = quant_config
def create_weights(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
if input_size % self.quant_config.pack_factor != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
qweight = Parameter(
torch.empty(
input_size // self.quant_config.pack_factor,
output_size,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
lookup_table = Parameter(
torch.empty(
output_size,
# One entry per possible code: 2**bits (equal to bits**2 only for 4-bit).
2**self.quant_config.weight_bits,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(lookup_table, {
"output_dim": 0,
})
return {
"qweight": qweight,
"lookup_table": lookup_table,
}
def apply_weights(self,
weights: Dict[str, torch.Tensor],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
lookup_table = weights["lookup_table"]
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
lookup_table)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
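
The shapes above suggest how the lookup table is used: each output channel has its own table of 16 values, and every 4-bit code selects one of them. The sketch below is a conceptual pure-Python model of that idea, not the `squeezellm_gemm` kernel. It assumes the codes are already unpacked into a `[input_size, output_size]` integer grid and does not reproduce the kernel's int32 bit layout; `lut_matvec` is a hypothetical name.

```python
from typing import List


def lut_matvec(x: List[float], codes: List[List[int]],
               lookup_table: List[List[float]]) -> List[float]:
    """Compute y[j] = sum_i x[i] * lookup_table[j][codes[i][j]]."""
    in_features, out_features = len(codes), len(codes[0])
    out = [0.0] * out_features  # zero-initialized accumulator
    for i in range(in_features):
        for j in range(out_features):
            # The code picks one value from output channel j's table.
            out[j] += x[i] * lookup_table[j][codes[i][j]]
    return out


# Two output channels, each with a 16-entry table (4-bit codes).
lookup_table = [[k * 0.5 for k in range(16)],
                [float(-k) for k in range(16)]]
codes = [[2, 1],
         [4, 3]]
x = [1.0, 2.0]
y = lut_matvec(x, codes, lookup_table)
print(y)
```

The accumulation into a zeroed buffer matches the note in `apply_weights` that the output tensor should be zero-initialized.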

@@ -1,37 +0,0 @@
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
}
class ParallelLinear:
@classmethod
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return ColumnParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
return quant_linear_cls(*args, **kwargs)
@classmethod
def row(cls, *args, **kwargs) -> RowParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return RowParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
return quant_linear_cls(*args, **kwargs)

@@ -1,102 +0,0 @@
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
class AWQColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.weight_bits == 0
assert (self.output_size_per_partition %
self.quant_config.pack_factor == 0)
self.qweight = Parameter(
torch.empty(
self.input_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class AWQRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert (self.input_size_per_partition %
self.quant_config.weight_bits == 0)
assert self.output_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
return out.reshape(out_shape)

@@ -21,7 +21,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Rotary Positional Embeddings."""
-from typing import Tuple, Union
+import math
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -167,3 +168,149 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
         return cache
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> float:
return (dim * math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 *
math.log(base))
# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
_yarn_find_correction_dim(high_rot, dim, base,
max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
dtype: torch.dtype,
device: torch.device) -> torch.Tensor:
if low == high:
high += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype, device=device) -
low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def _yarn_get_mscale(scale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: float = 32,
beta_slow: float = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(
_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float,
device="cuda")) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
return cache
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
is_neox_style: bool,
rope_scaling: Optional[Dict[str, Any]],
) -> RotaryEmbedding:
if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor)
elif scaling_type == "dynamic":
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
elif scaling_type == "yarn":
original_max_position = rope_scaling[
"original_max_position_embeddings"]
assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow")
}
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
original_max_position,
base, is_neox_style,
scaling_factor,
**extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
return rotary_emb
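The YaRN helpers above blend an interpolated and an extrapolated inverse frequency per rotary dimension, using a linear ramp between two correction bounds. The following is a minimal pure-Python sketch of that arithmetic, written without torch so it can be run anywhere. The function names (`find_correction_dim`, `find_correction_range`, `linear_ramp`, `yarn_inv_freq`, `get_mscale`) and the example sizes are illustrative assumptions, not part of the change itself.

```python
import math
from typing import List, Tuple


def find_correction_dim(num_rotations: float, dim: int, base: float,
                        max_position_embeddings: int) -> float:
    # Same inverse-dim formula as _yarn_find_correction_dim in the diff.
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(low_rot: float, high_rot: float, dim: int,
                          base: float, max_pos: int) -> Tuple[int, int]:
    low = math.floor(find_correction_dim(low_rot, dim, base, max_pos))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_pos))
    return max(low, 0), min(high, dim - 1)  # Clamp to valid indices.


def linear_ramp(low: float, high: float, n: int) -> List[float]:
    if low == high:
        high += 0.001  # Avoid division by zero.
    return [min(max((i - low) / (high - low), 0.0), 1.0) for i in range(n)]


def yarn_inv_freq(rotary_dim: int, base: float, max_pos: int,
                  scaling_factor: float, beta_fast: float = 32,
                  beta_slow: float = 1,
                  extrapolation_factor: float = 1.0) -> List[float]:
    pos_freqs = [base**(i / rotary_dim) for i in range(0, rotary_dim, 2)]
    low, high = find_correction_range(beta_fast, beta_slow, rotary_dim, base,
                                      max_pos)
    ramp = linear_ramp(low, high, rotary_dim // 2)
    inv_freq = []
    for freq, r in zip(pos_freqs, ramp):
        mask = (1.0 - r) * extrapolation_factor
        interpolation = 1.0 / (scaling_factor * freq)
        extrapolation = 1.0 / freq
        # mask == 1 keeps the original frequency, mask == 0 fully interpolates.
        inv_freq.append(interpolation * (1.0 - mask) + extrapolation * mask)
    return inv_freq


def get_mscale(scale: float = 1.0) -> float:
    # Attention magnitude correction, as in _yarn_get_mscale.
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0


# Example sizes are assumptions chosen for illustration only.
inv_freq = yarn_inv_freq(rotary_dim=128, base=10000.0, max_pos=4096,
                         scaling_factor=4.0)
```

In this sketch the lowest dimensions (fast rotations) keep their original frequency, while the highest dimensions are divided by the scaling factor, which matches the intent of the mask in `_compute_inv_freq`.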


@@ -47,15 +47,18 @@ class Sampler(nn.Module):
logits = _get_logits(hidden_states, embedding, embedding_bias, logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size) self.vocab_size)
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, input_metadata)
# Apply presence and frequency penalties. # Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata) output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0] assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties( presence_penalties, frequency_penalties, repetition_penalties = (
input_metadata) _get_penalties(input_metadata))
assert len(presence_penalties) == logits.shape[0] assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0] assert len(frequency_penalties) == logits.shape[0]
assert len(repetition_penalties) == logits.shape[0]
logits = _apply_penalties(logits, output_tokens, presence_penalties, logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties) frequency_penalties, repetition_penalties)
# Apply temperature scaling. # Apply temperature scaling.
temperatures = _get_temperatures(input_metadata) temperatures = _get_temperatures(input_metadata)
@@ -68,13 +71,18 @@ class Sampler(nn.Module):
logits.div_(t.unsqueeze(dim=1)) logits.div_(t.unsqueeze(dim=1))
# Apply top-p and top-k truncation. # Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size) top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == logits.shape[0] assert len(top_ps) == len(top_ks) == logits.shape[0]
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps) do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks) do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k: if do_top_p or do_top_k:
logits = _apply_top_p_top_k(logits, top_ps, top_ks) logits = _apply_top_p_top_k(logits, top_ps, top_ks)
do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
if do_min_p:
logits = _apply_min_p(logits, min_ps)
# We use float32 for probabilities and log probabilities. # We use float32 for probabilities and log probabilities.
# Compute the probabilities. # Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float) probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -108,39 +116,22 @@ def _prune_hidden_states(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
selected_token_indices: List[int] = [] hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
start_idx = 0 return hidden_states.index_select(0, input_metadata.selected_token_indices)
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
assert len(seq_ids) == 1, "Prompt input should have only one seq."
prompt_len = input_metadata.prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(start_idx, start_idx + prompt_len - 1))
selected_token_indices.append(start_idx + prompt_len - 1)
start_idx += prompt_len
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device=hidden_states.device)
return hidden_states.index_select(0, selected_token_indices)
def _get_penalties( def _get_penalties(
input_metadata: InputMetadata) -> Tuple[List[float], List[float]]: input_metadata: InputMetadata
) -> Tuple[List[float], List[float], List[float]]:
# Collect the presence and frequency penalties. # Collect the presence and frequency penalties.
presence_penalties: List[float] = [] presence_penalties: List[float] = []
frequency_penalties: List[float] = [] frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty f = sampling_params.frequency_penalty
r = sampling_params.repetition_penalty
if (i < input_metadata.num_prompts if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
# NOTE: We do not apply presence and frequency penalties for the # NOTE: We do not apply presence and frequency penalties for the
@@ -148,9 +139,11 @@ def _get_penalties(
prompt_len = input_metadata.prompt_lens[i] prompt_len = input_metadata.prompt_lens[i]
presence_penalties += [0] * (prompt_len - 1) presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1) frequency_penalties += [0] * (prompt_len - 1)
repetition_penalties += [1] * (prompt_len - 1)
presence_penalties += [p] * len(seq_ids) presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids) frequency_penalties += [f] * len(seq_ids)
return presence_penalties, frequency_penalties repetition_penalties += [r] * len(seq_ids)
return presence_penalties, frequency_penalties, repetition_penalties
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
@@ -169,11 +162,34 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
return output_tokens return output_tokens
def _apply_logits_processors(logits: torch.Tensor,
input_metadata: InputMetadata) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in input_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = input_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits
def _apply_penalties( def _apply_penalties(
logits: torch.Tensor, logits: torch.Tensor,
output_tokens: List[List[int]], output_tokens: List[List[int]],
presence_penalties: List[float], presence_penalties: List[float],
frequency_penalties: List[float], frequency_penalties: List[float],
repetition_penalties: List[float],
) -> torch.Tensor: ) -> torch.Tensor:
num_seqs, vocab_size = logits.shape num_seqs, vocab_size = logits.shape
for i in range(num_seqs): for i in range(num_seqs):
@@ -181,7 +197,9 @@ def _apply_penalties(
continue continue
p = presence_penalties[i] p = presence_penalties[i]
f = frequency_penalties[i] f = frequency_penalties[i]
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS: r = repetition_penalties[i]
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
r - 1.0) < _SAMPLING_EPS:
continue continue
break break
else: else:
@@ -205,7 +223,11 @@ def _apply_penalties(
bin_counts.scatter_add_(1, output_tokens_tensor, bin_counts.scatter_add_(1, output_tokens_tensor,
torch.ones_like(output_tokens_tensor)) torch.ones_like(output_tokens_tensor))
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin. bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
mask = bin_counts > 0
repetition_penalties = torch.tensor(repetition_penalties,
dtype=logits.dtype,
device=logits.device)
frequency_penalties = torch.tensor(frequency_penalties, frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype, dtype=logits.dtype,
device=logits.device) device=logits.device)
@@ -213,10 +235,15 @@ def _apply_penalties(
dtype=logits.dtype, dtype=logits.dtype,
device=logits.device) device=logits.device)
repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
repetition_penalties[~mask] = 1.0
logits = torch.where(logits > 0, logits / repetition_penalties,
logits * repetition_penalties)
# We follow the definition in OpenAI API. # We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details # Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0) logits -= presence_penalties.unsqueeze(dim=1) * mask
return logits return logits
@@ -239,15 +266,17 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
return temperatures return temperatures
def _get_top_p_top_k( def _get_top_p_top_k_min_p(
input_metadata: InputMetadata, input_metadata: InputMetadata,
vocab_size: int, vocab_size: int,
) -> Tuple[List[float], List[int]]: ) -> Tuple[List[float], List[int], List[float]]:
top_ps: List[float] = [] top_ps: List[float] = []
top_ks: List[int] = [] top_ks: List[int] = []
min_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group seq_ids, sampling_params = seq_group
top_p = sampling_params.top_p top_p = sampling_params.top_p
min_p = sampling_params.min_p
# k should not be greater than the vocab size. # k should not be greater than the vocab size.
top_k = min(sampling_params.top_k, vocab_size) top_k = min(sampling_params.top_k, vocab_size)
# k=-1 means no truncation. # k=-1 means no truncation.
@@ -257,9 +286,11 @@ def _get_top_p_top_k(
prompt_len = input_metadata.prompt_lens[i] prompt_len = input_metadata.prompt_lens[i]
top_ps += [top_p] * (prompt_len - 1) top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1) top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)
top_ps += [top_p] * len(seq_ids) top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids) top_ks += [top_k] * len(seq_ids)
return top_ps, top_ks min_ps += [min_p] * len(seq_ids)
return top_ps, top_ks, min_ps
def _apply_top_p_top_k( def _apply_top_p_top_k(
@@ -291,6 +322,24 @@ def _apply_top_p_top_k(
return logits return logits
def _apply_min_p(
logits: torch.Tensor,
min_ps: List[float],
) -> torch.Tensor:
"""
Adapted from
https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
"""
min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
probs = torch.softmax(logits, dim=-1)
top_probs, _ = probs.max(dim=-1, keepdim=True)
scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
tokens_to_remove = probs < scaled_min_p
logits = logits.masked_fill(tokens_to_remove, -float("inf"))
return logits
def _greedy_sample( def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]], selected_seq_groups: List[Tuple[List[int], SamplingParams]],
logprobs: torch.Tensor, logprobs: torch.Tensor,
@@ -407,21 +456,11 @@ def _sample(
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> List[Tuple[List[int], List[int]]]: ) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices = input_metadata.categorized_sample_indices
start_idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups): for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group _, sampling_params = seq_group
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: prompt token positions do not need sample, skip
prompt_len = input_metadata.prompt_lens[i]
start_idx += prompt_len - 1
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
num_seqs = len(seq_ids)
categorized_sample_indices[sampling_type].extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
for sampling_type in SamplingType: for sampling_type in SamplingType:
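The two sampling additions in this file, the repetition penalty and min-p filtering, are easier to follow on a single row of logits. The sketch below restates both in pure Python for one sequence. It is an illustration under assumptions, not the batched tensor code above: the helper names `softmax`, `apply_min_p` and `apply_repetition_penalty` are hypothetical, and a plain list stands in for a logits row.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # Shift for stability.
    total = sum(exps)
    return [e / total for e in exps]


def apply_min_p(logits: Sequence[float], min_p: float) -> List[float]:
    # Drop tokens whose probability is below min_p times the top probability.
    probs = softmax(logits)
    threshold = min_p * max(probs)
    return [
        x if p >= threshold else -math.inf for x, p in zip(logits, probs)
    ]


def apply_repetition_penalty(logits: Sequence[float],
                             output_token_ids: Sequence[int],
                             penalty: float) -> List[float]:
    # Only tokens that already appeared in the output are penalized.
    seen = set(output_token_ids)
    penalized = []
    for token_id, x in enumerate(logits):
        if token_id in seen:
            # Positive logits are divided, non-positive ones multiplied,
            # so a penalty above 1 always lowers the token's score.
            x = x / penalty if x > 0 else x * penalty
        penalized.append(x)
    return penalized


filtered = apply_min_p([2.0, 1.0, -1.0], min_p=0.1)
penalized = apply_repetition_penalty([2.0, -1.0, 0.5], [0, 1], penalty=2.0)
```

A penalty of exactly 1.0 and a `min_p` of 0.0 leave the logits unchanged, which is why the sampler above skips both steps when every sequence uses those defaults.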


@@ -0,0 +1,139 @@
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
rank: int) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank)
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding. Note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = (
vocab_range_from_global_vocab_size(
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.num_embeddings
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
Output logits weight matrices used in the Sampler. The weight and bias
tensors are padded to make sure they are divisible by the number of
model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None):
super().__init__(num_embeddings, embedding_dim, params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
else:
self.register_parameter("bias", None)
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")
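The padding and partitioning logic in this file can be checked with plain integers. The sketch below mirrors `pad_vocab_size` and the per-rank range computation, and shows how token ids outside a rank's range are masked to index 0 before the embedding lookup. The helper names `vocab_range` and `mask_token_ids`, and the example vocabulary size and world size, are assumptions made for illustration.

```python
from typing import List, Sequence, Tuple


def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Round up to the next multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def vocab_range(padded_vocab_size: int, rank: int,
                world_size: int) -> Tuple[int, int]:
    # The padded size is expected to split evenly across ranks.
    assert padded_vocab_size % world_size == 0
    per_partition = padded_vocab_size // world_size
    start = rank * per_partition
    return start, start + per_partition


def mask_token_ids(token_ids: Sequence[int], start: int,
                   end: int) -> Tuple[List[int], List[bool]]:
    # Tokens owned by another rank are mapped to local index 0 and flagged,
    # so their embedding rows can be zeroed before the all-reduce.
    masked = [(t - start) if start <= t < end else 0 for t in token_ids]
    is_foreign = [not (start <= t < end) for t in token_ids]
    return masked, is_foreign


padded = pad_vocab_size(32003)
start, end = vocab_range(padded, rank=1, world_size=4)
local_ids, foreign = mask_token_ids([5, 8016, 16031, 20000], start, end)
```

Because each token id falls in exactly one rank's range, summing the per-rank outputs with an all-reduce recovers the full embedding, which is what `forward` relies on.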


@@ -18,6 +18,7 @@ _MODEL_REGISTRY = {
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b "BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b "BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
"BloomForCausalLM": BloomForCausalLM, "BloomForCausalLM": BloomForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"FalconForCausalLM": FalconForCausalLM, "FalconForCausalLM": FalconForCausalLM,
"GPT2LMHeadModel": GPT2LMHeadModel, "GPT2LMHeadModel": GPT2LMHeadModel,
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM, "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
@@ -27,18 +28,16 @@ _MODEL_REGISTRY = {
"LlamaForCausalLM": LlamaForCausalLM, "LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MistralForCausalLM": MistralForCausalLM, "MistralForCausalLM": MistralForCausalLM,
# The MPT class in transformers uses a lowercase name.
"MptForCausalLM": MPTForCausalLM,
"MPTForCausalLM": MPTForCausalLM, "MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM, "OPTForCausalLM": OPTForCausalLM,
"PhiForCausalLM": PhiForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel, "QWenLMHeadModel": QWenLMHeadModel,
"RWForCausalLM": FalconForCausalLM, "RWForCausalLM": FalconForCausalLM,
"YiForCausalLM": YiForCausalLM,
} }
# FIXME(woosuk): Remove this once all models support quantization.
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
LlamaForCausalLM,
MistralForCausalLM,
]
@contextlib.contextmanager @contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype): def _set_default_torch_dtype(dtype: torch.dtype):
@@ -62,14 +61,12 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config) model_class = _get_model_architecture(model_config.hf_config)
# Get the quantization config. # Get the (maybe quantized) linear method.
quant_config = None linear_method = None
if model_config.quantization is not None: if model_config.quantization is not None:
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
raise ValueError(
f"Quantization is not supported for {model_class}.")
quant_config = get_quant_config(model_config.quantization, quant_config = get_quant_config(model_config.quantization,
model_config.model, model_config.model,
model_config.hf_config,
model_config.download_dir) model_config.download_dir)
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
@@ -85,14 +82,12 @@ def get_model(model_config: ModelConfig) -> nn.Module:
f"{model_config.dtype} is not supported for quantization " f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: " f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}") f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
with _set_default_torch_dtype(model_config.dtype): with _set_default_torch_dtype(model_config.dtype):
# Create a model instance. # Create a model instance.
# The weights will be initialized as empty tensors. # The weights will be initialized as empty tensors.
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION: model = model_class(model_config.hf_config, linear_method)
model = model_class(model_config.hf_config, quant_config)
else:
model = model_class(model_config.hf_config)
if model_config.load_format == "dummy": if model_config.load_format == "dummy":
model = model.cuda() model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
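The change to `get_model` replaces the per-model quantization allowlist with a single `linear_method` object that every model constructor accepts. A toy sketch of that flow follows. All class and function names here (`ToyLinearMethod`, `ToyQuantConfig`, `ToyModel`, `TOY_REGISTRY`, `build_model`) are hypothetical stand-ins, and the minimum-capability check is an assumption about the lines elided between the hunks above.

```python
from typing import Any, Dict, Optional, Sequence, Tuple, Type


class ToyLinearMethod:
    """Hypothetical stand-in for a (possibly quantized) linear method."""

    def __init__(self, name: str) -> None:
        self.name = name


class ToyQuantConfig:
    """Hypothetical quantization config with a minimum GPU capability."""

    min_capability = 80  # Assumed threshold, for illustration only.

    def get_linear_method(self) -> ToyLinearMethod:
        return ToyLinearMethod("toy-quant")


class ToyModel:
    def __init__(self, hf_config: Dict[str, Any],
                 linear_method: Optional[ToyLinearMethod] = None) -> None:
        self.hf_config = hf_config
        self.linear_method = linear_method


TOY_REGISTRY: Dict[str, Type[ToyModel]] = {"ToyForCausalLM": ToyModel}


def capability_to_int(capability: Tuple[int, int]) -> int:
    # (8, 0) becomes 80, matching capability[0] * 10 + capability[1].
    major, minor = capability
    return major * 10 + minor


def build_model(architectures: Sequence[str], hf_config: Dict[str, Any],
                quant_config: Optional[ToyQuantConfig],
                capability: Tuple[int, int]) -> ToyModel:
    model_class = next(
        (TOY_REGISTRY[a] for a in architectures if a in TOY_REGISTRY), None)
    if model_class is None:
        raise ValueError(f"Unsupported architectures: {list(architectures)}")
    linear_method = None
    if quant_config is not None:
        if capability_to_int(capability) < quant_config.min_capability:
            raise ValueError("GPU capability too low for this method.")
        linear_method = quant_config.get_linear_method()
    # Every model takes the same two arguments, quantized or not.
    return model_class(hf_config, linear_method)


plain = build_model(["ToyForCausalLM"], {}, None, (7, 5))
quantized = build_model(["ToyForCausalLM"], {}, ToyQuantConfig(), (8, 0))
```

Passing `None` for the linear method keeps the unquantized path identical for all models, so no model-specific branching is needed at construction time.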


@@ -12,13 +12,17 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mistral import MistralForCausalLM from vllm.model_executor.models.mistral import MistralForCausalLM
from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel from vllm.model_executor.models.qwen import QWenLMHeadModel
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
from vllm.model_executor.models.yi import YiForCausalLM
__all__ = [ __all__ = [
"AquilaForCausalLM", "AquilaForCausalLM",
"BaiChuanForCausalLM", "BaiChuanForCausalLM",
"BaichuanForCausalLM", "BaichuanForCausalLM",
"BloomForCausalLM", "BloomForCausalLM",
"ChatGLMForCausalLM",
"FalconForCausalLM", "FalconForCausalLM",
"GPT2LMHeadModel", "GPT2LMHeadModel",
"GPTBigCodeForCausalLM", "GPTBigCodeForCausalLM",
@@ -28,6 +32,8 @@ __all__ = [
"LlamaForCausalLM", "LlamaForCausalLM",
"MPTForCausalLM", "MPTForCausalLM",
"OPTForCausalLM", "OPTForCausalLM",
"PhiForCausalLM",
"QWenLMHeadModel", "QWenLMHeadModel",
"MistralForCausalLM", "MistralForCausalLM",
"YiForCausalLM",
] ]


@@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@@ -33,15 +33,17 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab, VocabParallelEmbedding, ParallelLMHead)
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig from vllm.transformers_utils.configs.aquila import AquilaConfig
@@ -55,20 +57,17 @@ class AquilaMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, hidden_size, [intermediate_size] * 2,
2 * intermediate_size,
bias=False, bias=False,
gather_output=False, linear_method=linear_method)
) self.down_proj = RowParallelLinear(intermediate_size,
self.down_proj = RowParallelLinear( hidden_size,
intermediate_size, bias=False,
hidden_size, linear_method=linear_method)
bias=False,
input_is_parallel=True,
)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
@@ -110,6 +109,8 @@ class AquilaAttention(nn.Module):
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
rope_scaling: Optional[Dict[str, Any]] = None,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -127,28 +128,29 @@ class AquilaAttention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim, self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False, bias=False,
gather_output=False, linear_method=linear_method,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
input_is_parallel=True, linear_method=linear_method,
) )
self.attn = PagedAttentionWithRoPE( self.attn = PagedAttentionWithRoPE(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
rotary_dim=self.head_dim,
base=self.rope_theta, base=self.rope_theta,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
) rope_scaling=rope_scaling)
def forward( def forward(
self, self,
@@ -169,10 +171,15 @@ class AquilaAttention(nn.Module):
class AquilaDecoderLayer(nn.Module): class AquilaDecoderLayer(nn.Module):
def __init__(self, config: AquilaConfig): def __init__(
self,
config: AquilaConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", max_position_embeddings = getattr(config, "max_position_embeddings",
8192) 8192)
self.self_attn = AquilaAttention( self.self_attn = AquilaAttention(
@@ -181,11 +188,14 @@ class AquilaDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
rope_scaling=rope_scaling,
linear_method=linear_method,
) )
self.mlp = AquilaMLP( self.mlp = AquilaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method,
) )
self.input_layernorm = AquilaRMSNorm(config.hidden_size, self.input_layernorm = AquilaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@@ -222,19 +232,22 @@ class AquilaDecoderLayer(nn.Module):
class AquilaModel(nn.Module): class AquilaModel(nn.Module):
def __init__(self, config: AquilaConfig): def __init__(
self,
config: AquilaConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
#vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers) AquilaDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
]) ])
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -267,17 +280,16 @@ class AquilaModel(nn.Module):
class AquilaForCausalLM(nn.Module): class AquilaForCausalLM(nn.Module):
def __init__(self, config): def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = AquilaModel(config) self.linear_method = linear_method
vocab_size = ((config.vocab_size + 63) // 64) * 64 self.model = AquilaModel(config, linear_method)
self.lm_head = ColumnParallelLinear( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
@@ -294,79 +306,33 @@ class AquilaForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = [
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tp_size = get_tensor_model_parallel_world_size() stacked_params_mapping = [
tensor_model_parallel_rank = get_tensor_model_parallel_rank() # (param_name, shard_name, shard_id)
q_proj_shard_size = (self.config.hidden_size // tp_size) ("qkv_proj", "q_proj", "q"),
kv_proj_shard_size = (self.config.hidden_size // ("qkv_proj", "k_proj", "k"),
self.config.num_attention_heads * ("qkv_proj", "v_proj", "v"),
self.config.num_key_value_heads // tp_size) ("gate_up_proj", "gate_proj", 0),
attention_weight_specs = [ ("gate_up_proj", "up_proj", 1),
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
] ]
state_dict = self.state_dict() params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "qkv_proj")] param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
loaded_weight = loaded_weight[ weight_loader(param, loaded_weight, shard_id)
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break break
if is_attention_weight: else:
continue param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
is_gate_up_weight = False default_weight_loader)
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): weight_loader(param, loaded_weight)
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
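
Review note: the rewritten `load_weights` relies on Python's `for ... else` to route each checkpoint tensor either to a fused parameter (`qkv_proj`, `gate_up_proj`) or to a parameter of the same name. The sketch below reproduces only the name resolution, with no tensors involved. `resolve_param_name` and the sample checkpoint names are hypothetical and are not part of vLLM; the mapping table is copied from the diff above.

```python
# Illustrative sketch only: resolve_param_name is a hypothetical helper,
# not a vLLM function. The mapping table mirrors the one in the diff.
STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def resolve_param_name(name):
    """Return (target_param_name, shard_id) for a checkpoint weight name."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name not in name:
            continue
        # First match wins; `break` skips the `else` branch below.
        resolved = (name.replace(weight_name, param_name), shard_id)
        break
    else:
        # No stacked entry matched: the weight keeps its own name and
        # carries no shard id in this sketch.
        resolved = (name, None)
    return resolved


if __name__ == "__main__":
    for checkpoint_name in (
            "model.layers.0.self_attn.k_proj.weight",
            "model.layers.0.mlp.up_proj.weight",
            "model.layers.0.mlp.down_proj.weight",
    ):
        print(checkpoint_name, "->", resolve_param_name(checkpoint_name))
```

In the real loader the resolved name indexes `params_dict`, and the shard id is passed to the parameter's `weight_loader`, which is where the tensor-parallel slicing now lives.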


@@ -30,18 +30,20 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE, from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
PagedAttentionWithALiBi) PagedAttentionWithALiBi)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
convert_pyslice_to_tensor, hf_model_weights_iterator, VocabParallelEmbedding, ParallelLMHead)
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
@@ -80,20 +82,17 @@ class BaiChuanMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, hidden_size, [intermediate_size] * 2,
2 * intermediate_size,
bias=False, bias=False,
gather_output=False, linear_method=linear_method)
) self.down_proj = RowParallelLinear(intermediate_size,
self.down_proj = RowParallelLinear( hidden_size,
intermediate_size, bias=False,
hidden_size, linear_method=linear_method)
bias=False,
input_is_parallel=True,
)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
@@ -116,6 +115,7 @@ class BaiChuanAttention(nn.Module):
position_embedding: str, position_embedding: str,
rope_theta: float = 10000, rope_theta: float = 10000,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -131,17 +131,19 @@ class BaiChuanAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name # pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear( self.W_pack = QKVParallelLinear(
hidden_size, hidden_size,
3 * hidden_size, self.head_dim,
self.total_num_heads,
self.total_num_heads,
bias=False, bias=False,
gather_output=False, linear_method=linear_method,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
input_is_parallel=True, linear_method=linear_method,
) )
# Create the alibi slopes and slice them. # Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI": if self.postion_embedding == "ALIBI":
@@ -188,7 +190,10 @@ class BaiChuanAttention(nn.Module):
class BaiChuanDecoderLayer(nn.Module): class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str): def __init__(self,
config: BaiChuanConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
@@ -200,11 +205,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding=position_embedding, position_embedding=position_embedding,
rope_theta=rope_theta, rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@@ -218,10 +225,15 @@ class BaiChuanDecoderLayer(nn.Module):
kv_cache: KVCache, kv_cache: KVCache,
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states if residual is None:
hidden_states = self.input_layernorm(hidden_states) residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -229,19 +241,20 @@ class BaiChuanDecoderLayer(nn.Module):
input_metadata=input_metadata, input_metadata=input_metadata,
cache_event=cache_event, cache_event=cache_event,
) )
hidden_states = residual + hidden_states
# Fully Connected # Fully Connected
residual = hidden_states hidden_states, residual = self.post_attention_layernorm(
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states return hidden_states, residual
return hidden_states
class BaiChuanModel(nn.Module): class BaiChuanModel(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str): def __init__(self,
config: BaiChuanConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
@@ -252,7 +265,7 @@ class BaiChuanModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding) BaiChuanDecoderLayer(config, position_embedding, linear_method)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -266,35 +279,36 @@ class BaiChuanModel(nn.Module):
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
if cache_events is None: if cache_events is None:
cache_event = None cache_event = None
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, input_metadata,
cache_event, cache_event,
residual,
) )
hidden_states = self.norm(hidden_states) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class BaiChuanBaseForCausalLM(nn.Module): class BaiChuanBaseForCausalLM(nn.Module):
def __init__(self, config, position_embedding: str): def __init__(self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = BaiChuanModel(config, position_embedding) self.linear_method = linear_method
self.lm_head = ColumnParallelLinear( self.model = BaiChuanModel(config, position_embedding, linear_method)
config.hidden_size, self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
config.vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
@@ -311,79 +325,46 @@ class BaiChuanBaseForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = []
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tp_world_size = get_tensor_model_parallel_world_size() stacked_params_mapping = [
tp_rank = get_tensor_model_parallel_rank() # (param_name, shard_name, shard_id)
state_dict = self.state_dict() ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "gate_up_proj")] param = params_dict[name.replace(weight_name, param_name)]
shard_size = param.shape[0] // 2 weight_loader = param.weight_loader
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * weight_loader(param, loaded_weight, shard_id)
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break break
if is_gate_up_weight: else:
continue param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
param = state_dict[name] default_weight_loader)
weight_loader(param, loaded_weight)
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tp_rank)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
def __init__(self, config): def __init__(self,
super().__init__(config, "ALIBI") config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__(config, "ALIBI", linear_method)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
def __init__(self, config): def __init__(self,
super().__init__(config, "ROPE") config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__(config, "ROPE", linear_method)
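
Review note: the Baichuan decoder layer now returns `(hidden_states, residual)` and leaves the residual add to the next norm call. The toy sketch below checks that this deferred form produces the same values as the old explicit `residual + hidden_states` form. It uses plain Python lists instead of tensors. `rms_norm`, `fused_add_rms_norm`, and `toy_sublayer` are hypothetical stand-ins, and the sketch assumes the fused norm returns the normalized sum together with the sum itself.

```python
import math


def rms_norm(x, eps=1e-6):
    """Unweighted RMS norm over a list of floats (toy stand-in)."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def fused_add_rms_norm(x, residual):
    """Assumed fused contract: add first, then normalize the sum."""
    summed = [a + b for a, b in zip(x, residual)]
    return rms_norm(summed), summed


def toy_sublayer(x):
    # Hypothetical stand-in for the attention or MLP sublayer.
    return [0.5 * v + 1.0 for v in x]


def old_layer(h):
    """Pre-change layer: the residual add happens inside the layer."""
    residual = h
    h = toy_sublayer(rms_norm(h))
    h = [a + b for a, b in zip(residual, h)]
    residual = h
    h = toy_sublayer(rms_norm(h))
    return [a + b for a, b in zip(residual, h)]


def new_layer(h, residual):
    """Post-change layer: the residual add is deferred to the next norm."""
    if residual is None:
        residual = h
        h = rms_norm(h)
    else:
        h, residual = fused_add_rms_norm(h, residual)
    h = toy_sublayer(h)
    h, residual = fused_add_rms_norm(h, residual)
    h = toy_sublayer(h)
    return h, residual


def run_old(h, num_layers):
    for _ in range(num_layers):
        h = old_layer(h)
    return rms_norm(h)


def run_new(h, num_layers):
    residual = None
    for _ in range(num_layers):
        h, residual = new_layer(h, residual)
    # The final norm performs the last pending residual add.
    h, _ = fused_add_rms_norm(h, residual)
    return h
```

The point of the restructuring is that every add now sits directly in front of a norm, so a single fused kernel can do both. The values are unchanged; only the place where the add happens moves.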


@@ -30,14 +30,17 @@ from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator, from vllm.model_executor.layers.vocab_parallel_embedding import (
load_tensor_parallel_weights) VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -70,7 +73,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
class BloomAttention(nn.Module): class BloomAttention(nn.Module):
def __init__(self, config: BloomConfig): def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.total_num_heads = config.n_head self.total_num_heads = config.n_head
@@ -81,17 +88,18 @@ class BloomAttention(nn.Module):
assert self.total_num_heads % tp_world_size == 0 assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size self.num_heads = self.total_num_heads // tp_world_size
self.query_key_value = ColumnParallelLinear( self.query_key_value = QKVParallelLinear(
self.hidden_size, self.hidden_size,
3 * self.hidden_size, self.head_dim,
self.total_num_heads,
bias=True, bias=True,
gather_output=False, linear_method=linear_method,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
input_is_parallel=True, linear_method=linear_method,
) )
# Create the alibi slopes and slice them. # Create the alibi slopes and slice them.
@@ -125,40 +133,49 @@ class BloomAttention(nn.Module):
class BloomMLP(nn.Module): class BloomMLP(nn.Module):
def __init__(self, config: BloomConfig): def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear( self.dense_h_to_4h = ColumnParallelLinear(
hidden_size, hidden_size,
4 * hidden_size, 4 * hidden_size,
gather_output=False, linear_method=linear_method,
) )
self.act = get_act_fn("gelu") quant_config = getattr(linear_method, "quant_config", None)
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size, 4 * hidden_size,
hidden_size, hidden_size,
input_is_parallel=True, linear_method=linear_method,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.dense_h_to_4h(x) x, _ = self.dense_h_to_4h(x)
x = self.act(x) x = self.gelu_impl(x)
x, _ = self.dense_4h_to_h(x) x, _ = self.dense_4h_to_h(x)
return x return x
class BloomBlock(nn.Module): class BloomBlock(nn.Module):
def __init__(self, config: BloomConfig): def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size, self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config) self.self_attention = BloomAttention(config, linear_method)
self.post_attention_layernorm = nn.LayerNorm( self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon) hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config) self.mlp = BloomMLP(config, linear_method)
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm) config.apply_residual_connection_post_layernorm)
@@ -203,7 +220,11 @@ class BloomBlock(nn.Module):
class BloomModel(nn.Module): class BloomModel(nn.Module):
def __init__(self, config: BloomConfig): def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@@ -216,8 +237,10 @@ class BloomModel(nn.Module):
self.embed_dim, eps=config.layer_norm_epsilon) self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList( self.h = nn.ModuleList([
[BloomBlock(config) for _ in range(config.num_hidden_layers)]) BloomBlock(config, linear_method)
for _ in range(config.num_hidden_layers)
])
# Final Layer Norm # Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -251,12 +274,15 @@ class BloomModel(nn.Module):
class BloomForCausalLM(nn.Module): class BloomForCausalLM(nn.Module):
def __init__(self, config: BloomConfig): def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = BloomModel(config) self.linear_method = linear_method
# TODO(zhuohan): create a new weight after implementing pipeline self.transformer = BloomModel(config, linear_method)
# parallelism
self.lm_head_weight = self.transformer.word_embeddings.weight self.lm_head_weight = self.transformer.word_embeddings.weight
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@@ -274,55 +300,36 @@ class BloomForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tp_rank = get_tensor_model_parallel_rank() params_dict = dict(self.named_parameters(remove_duplicate=False))
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if name == "lm_head.weight": if name == "lm_head.weight":
# Since hidden_states are parallelized, we need to continue
# load lm_head.weight in parallel. if not name.startswith("transformer."):
self._column_parallel_weights.append(name) name = "transformer." + name
# If lm_head is provided, use it instead. param = params_dict[name]
param = self.lm_head_weight
else:
if not name.startswith("transformer."):
name = "transformer." + name
param = state_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
# NOTE(woosuk): BLOOM's fused QKV has the shape of # NOTE: BLOOM's fused QKV's output_dim has the shape of
# [num_heads * 3 * head_size, hidden_size], while the # (num_heads * 3 * head_size), while the
# required shape is [3 * num_heads * head_size, hidden_size]. # required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion. # Thus, we need weight conversion.
shard_size = param.shape[0] output_dim = getattr(param, "output_dim", None)
start = shard_size * tp_rank
end = shard_size * (tp_rank + 1)
loaded_weight = loaded_weight[start:end]
num_heads = self.config.num_attention_heads num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size if output_dim is not None:
head_size = hidden_size // num_heads loaded_weight_shape = loaded_weight.shape
if "query_key_value.weight" in name: loaded_weight = loaded_weight.view(
loaded_weight = loaded_weight.view(-1, 3, head_size, loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
hidden_size) loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(
loaded_weight = loaded_weight.reshape(-1, hidden_size) output_dim, output_dim + 1)
elif "query_key_value.bias" in name: loaded_weight = loaded_weight.reshape(loaded_weight_shape)
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1) weight_loader = getattr(param, "weight_loader",
loaded_weight = loaded_weight.reshape(-1) default_weight_loader)
else: weight_loader(param, loaded_weight)
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
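
Review note: the `view` / `transpose` / `reshape` sequence in the BLOOM loader regroups the fused QKV output dimension from `(num_heads, 3, head_size)` order to `(3, num_heads, head_size)` order. The sketch below applies the same reordering to a flat Python list so the index arithmetic is visible. `regroup_fused_qkv` is a hypothetical helper, not vLLM code, and each list element stands for one row of the weight or one entry of the bias.

```python
def regroup_fused_qkv(rows, num_heads):
    """Reorder rows from (head, qkv, head_size) to (qkv, head, head_size)."""
    head_size, remainder = divmod(len(rows), num_heads * 3)
    if remainder:
        raise ValueError("len(rows) must be divisible by 3 * num_heads")
    out = []
    for part in range(3):  # 0 = q, 1 = k, 2 = v
        for head in range(num_heads):
            # Offset of this head's q, k, or v block in the source layout.
            start = (head * 3 + part) * head_size
            out.extend(rows[start:start + head_size])
    return out


if __name__ == "__main__":
    # Two heads, head_size 2, labelled so the source layout is readable.
    source = [
        "h0q0", "h0q1", "h0k0", "h0k1", "h0v0", "h0v1",
        "h1q0", "h1q1", "h1k0", "h1k1", "h1v0", "h1v1",
    ]
    print(regroup_fused_qkv(source, num_heads=2))
```

After regrouping, all query rows come first, then all key rows, then all value rows. A loader that shards q, k, and v separately expects that layout.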


@@ -0,0 +1,376 @@
# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GLMAttention(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else
config.num_attention_heads)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we
# partition the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
linear_method=linear_method,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
linear_method=linear_method,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim // 2,
num_kv_heads=self.num_kv_heads,
is_neox_style=False,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
key_cache, value_cache = kv_cache
context_layer = self.attn(
position_ids,
q,
k,
v,
key_cache,
value_cache,
input_metadata,
cache_event,
)
attn_output, _ = self.dense(context_layer)
return attn_output
class GLMMLP(nn.Module):
"""MLP.
The MLP takes an input with hidden size h, projects it to a wider
intermediate dimension (config.ffn_hidden_size), applies a nonlinear
transformation, and projects the result back to hidden size h.
"""
    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """
    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, linear_method)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Transformer class."""
    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [GLMBlock(config, linear_method) for i in range(self.num_layers)])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        for i in range(self.num_layers):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                input_metadata=input_metadata,
                cache_event=cache_event,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, linear_method)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ):
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )

        return hidden_states


class ChatGLMForCausalLM(nn.Module):
    def __init__(
        self,
        config: ChatGLMConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.linear_method = linear_method
        self.transformer = ChatGLMModel(config, linear_method)
        self.lm_head_weight = self.transformer.output_layer.weight
        self.sampler = Sampler(config.padded_vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
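
Reviewer note: the `load_weights` loop above relies on each parameter optionally carrying its own `weight_loader` attribute, with `default_weight_loader` as the fallback. The following is a minimal pure-Python sketch of that dispatch pattern only. `FakeParam` and `make_sharded_loader` are hypothetical stand-ins invented for illustration (plain lists instead of tensors); they are not vLLM classes, and the real loaders shard along tensor dimensions rather than flat lists.

```python
from typing import Callable, Dict, List, Optional


class FakeParam:
    """Hypothetical stand-in for a model parameter (a flat list of floats)."""

    def __init__(self, size: int, weight_loader: Optional[Callable] = None):
        self.data: List[float] = [0.0] * size
        if weight_loader is not None:
            # Parallel layers attach a custom loader to the parameter itself.
            self.weight_loader = weight_loader


def default_weight_loader(param: FakeParam, loaded_weight: List[float]) -> None:
    """Fallback: copy the checkpoint values unchanged; sizes must match."""
    assert len(param.data) == len(loaded_weight)
    param.data[:] = loaded_weight


def make_sharded_loader(rank: int, world_size: int) -> Callable:
    """Hypothetical loader that keeps only this rank's contiguous slice."""

    def loader(param: FakeParam, loaded_weight: List[float]) -> None:
        shard = len(loaded_weight) // world_size
        param.data[:] = loaded_weight[rank * shard:(rank + 1) * shard]

    return loader


def load_weights(params_dict: Dict[str, FakeParam],
                 checkpoint: Dict[str, List[float]]) -> None:
    """Mirror the loop in the diff: skip, rename, then dispatch per parameter."""
    for name, loaded_weight in checkpoint.items():
        if "rotary_pos_emb.inv_freq" in name:
            continue
        if "word_embeddings" in name:
            name = name.replace(".word_embeddings", "")
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)


params = {
    "embedding.weight": FakeParam(2, make_sharded_loader(rank=1, world_size=2)),
    "norm.weight": FakeParam(3),
}
checkpoint = {
    "embedding.word_embeddings.weight": [1.0, 2.0, 3.0, 4.0],
    "norm.weight": [5.0, 6.0, 7.0],
    "rotary_pos_emb.inv_freq": [9.0],
}
load_weights(params, checkpoint)
```

The point of the pattern is that the model-level loop no longer needs to know which weights are column- or row-parallel; that knowledge lives with the layer that created the parameter.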


@@ -27,20 +27,23 @@ from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig from transformers import FalconConfig as HF_FalconConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import (PagedAttention, from vllm.model_executor.layers.attention import (PagedAttention,
PagedAttentionWithALiBi, PagedAttentionWithALiBi,
PagedAttentionWithRoPE) PagedAttentionWithRoPE)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, from vllm.model_executor.layers.vocab_parallel_embedding import (
hf_model_weights_iterator, VocabParallelEmbedding, ParallelLMHead)
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
@@ -48,19 +51,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = x @ self.weight.T
if self.bias is None:
return hidden_states
return hidden_states + self.bias
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
@@ -86,7 +76,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
class FalconAttention(nn.Module): class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@@ -103,41 +97,29 @@ class FalconAttention(nn.Module):
if self.new_decoder_architecture: if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads self.total_num_kv_heads = config.num_kv_heads
assert self.total_num_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
skip_bias_add=True,
)
elif self.multi_query: elif self.multi_query:
self.total_num_kv_heads = 1 self.total_num_kv_heads = 1
self.num_kv_heads = 1
self.query = ColumnParallelLinear(
self.hidden_size,
self.total_num_heads * self.head_dim,
bias=config.bias,
gather_output=False,
skip_bias_add=True,
)
self.key_value = FalconLinear(self.hidden_size,
2 * self.head_dim,
bias=config.bias)
else: else:
self.total_num_kv_heads = self.total_num_heads self.total_num_kv_heads = self.total_num_heads
self.num_kv_heads = self.num_heads if self.total_num_kv_heads >= tp_size:
self.query_key_value = ColumnParallelLinear( # Number of KV heads is greater than TP size, so we partition
self.hidden_size, # the KV heads across multiple tensor parallel GPUs.
(self.total_num_heads + 2 * self.total_num_kv_heads) * assert self.total_num_kv_heads % tp_size == 0
self.head_dim, else:
bias=config.bias, # Number of KV heads is less than TP size, so we replicate
gather_output=False, # the KV heads across multiple tensor parallel GPUs.
skip_bias_add=True, assert tp_size % self.total_num_kv_heads == 0
) self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.bias,
skip_bias_add=True,
linear_method=linear_method,
)
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
@@ -149,8 +131,8 @@ class FalconAttention(nn.Module):
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=config.bias, bias=config.bias,
input_is_parallel=True,
skip_bias_add=True, skip_bias_add=True,
linear_method=linear_method,
reduce_results=self.reduce_row_parallel_results) reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary self.use_rotary = config.rotary
@@ -196,18 +178,10 @@ class FalconAttention(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: ) -> torch.Tensor:
if not self.new_decoder_architecture and self.multi_query: qkv, bias = self.query_key_value(hidden_states)
q, bias = self.query(hidden_states) if bias is not None:
if bias is not None: qkv += bias
q += bias q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
kv = self.key_value(hidden_states)
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
else:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
if self.use_rotary: if self.use_rotary:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
@@ -221,25 +195,30 @@ class FalconAttention(nn.Module):
class FalconMLP(nn.Module): class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size, self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size, 4 * hidden_size,
bias=config.bias, bias=config.bias,
gather_output=False, skip_bias_add=True,
skip_bias_add=True) linear_method=linear_method)
self.act = nn.GELU() quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.reduce_row_parallel_results = not (config.new_decoder_architecture self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn) or config.parallel_attn)
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size, 4 * hidden_size,
hidden_size, hidden_size,
bias=config.bias, bias=config.bias,
input_is_parallel=True,
skip_bias_add=True, skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results) reduce_results=self.reduce_row_parallel_results,
linear_method=linear_method)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here. # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
@@ -253,12 +232,16 @@ class FalconMLP(nn.Module):
class FalconDecoderLayer(nn.Module): class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config) self.self_attention = FalconAttention(config, linear_method)
self.mlp = FalconMLP(config) self.mlp = FalconMLP(config, linear_method)
self.config = config self.config = config
if config.new_decoder_architecture: if config.new_decoder_architecture:
@@ -334,7 +317,11 @@ class FalconDecoderLayer(nn.Module):
class FalconModel(nn.Module): class FalconModel(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@@ -349,7 +336,8 @@ class FalconModel(nn.Module):
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.h = nn.ModuleList([
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers) FalconDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
]) ])
# Final Layer Norm # Final Layer Norm
@@ -383,15 +371,18 @@ class FalconModel(nn.Module):
class FalconForCausalLM(nn.Module): class FalconForCausalLM(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = FalconModel(config) self.linear_method = linear_method
self.lm_head = ColumnParallelLinear( self.transformer = FalconModel(config, linear_method)
config.hidden_size, self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
bias=False, config.hidden_size,
gather_output=False,
) )
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@@ -415,89 +406,44 @@ class FalconForCausalLM(nn.Module):
return next_tokens return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank()
hidden_size = self.config.hidden_size
total_num_heads = self.config.num_attention_heads total_num_heads = self.config.num_attention_heads
num_heads = total_num_heads // tp_size
head_size = hidden_size // total_num_heads
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if self.config.new_decoder_architecture: if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads total_num_kv_heads = self.config.num_kv_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
elif self.config.multi_query: elif self.config.multi_query:
total_num_kv_heads = 1 total_num_kv_heads = 1
num_kv_heads = 1
separated_q_kv = True
kv_head_start = 0
kv_head_end = 1
else: else:
total_num_kv_heads = total_num_heads total_num_kv_heads = total_num_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
state_dict = self.state_dict() params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
loaded_weight = convert_pyslice_to_tensor(loaded_weight) output_dim = getattr(param, "output_dim", None)
loaded_weight_size = loaded_weight.size() loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view( loaded_weight = loaded_weight.view(
total_num_kv_heads, num_query_heads_per_kv_head + 2, loaded_weight_shape[:output_dim] +
head_size, *loaded_weight_size[1:]) (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
loaded_weight_shape[output_dim + 1:])
wq = loaded_weight.narrow(
output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wk = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wv = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head + 1,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:]) weight_loader = getattr(param, "weight_loader",
wk = loaded_weight[:, [-2]].reshape(-1, default_weight_loader)
*loaded_weight_size[1:]) weight_loader(param, loaded_weight)
wv = loaded_weight[:, [-1]].reshape(-1,
*loaded_weight_size[1:])
wq = wq[head_size * head_start:head_size * head_end]
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
if separated_q_kv:
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("query_key_value", "query")
kv_weight_name = name.replace("query_key_value",
"key_value")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
continue
else:
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
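
Reviewer note: the new `load_weights` for Falcon regroups the fused `query_key_value` checkpoint tensor before handing it to the parameter's `weight_loader`. As the `view(...)` and `narrow(...)` calls indicate, the checkpoint stores, for each KV head, its query heads followed by one key head and one value head; the loader wants all queries, then all keys, then all values. The sketch below reproduces that reordering on a plain list of row labels instead of a tensor. The function name `regroup_fused_qkv` is hypothetical, and only the output dimension is modelled.

```python
from typing import List, Sequence


def regroup_fused_qkv(rows: Sequence[str],
                      total_num_kv_heads: int,
                      num_query_heads_per_kv_head: int) -> List[str]:
    """Reorder rows from per-KV-group [q..., k, v] layout to [all q | all k | all v]."""
    group_width = num_query_heads_per_kv_head + 2
    num_heads = total_num_kv_heads * group_width
    assert len(rows) % num_heads == 0
    head_size = len(rows) // num_heads

    # Split the output dimension into per-head blocks of head_size rows.
    heads = [
        list(rows[i * head_size:(i + 1) * head_size]) for i in range(num_heads)
    ]

    wq: List[str] = []
    wk: List[str] = []
    wv: List[str] = []
    for g in range(total_num_kv_heads):
        group = heads[g * group_width:(g + 1) * group_width]
        for head in group[:num_query_heads_per_kv_head]:
            wq.extend(head)
        wk.extend(group[num_query_heads_per_kv_head])
        wv.extend(group[num_query_heads_per_kv_head + 1])
    return wq + wk + wv


# Two KV heads, two query heads per KV head, head_size of one row.
fused = ["q0", "q1", "k0", "v0", "q2", "q3", "k1", "v1"]
regrouped = regroup_fused_qkv(fused, total_num_kv_heads=2,
                              num_query_heads_per_kv_head=2)
```

After this regrouping the tensor has the generic Q, K, V layout, so the per-rank slicing that the old code did by hand (`head_start`, `kv_head_start`, and so on) can be left to the layer's own loader.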


@@ -30,15 +30,17 @@ from transformers import GPT2Config
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
convert_pyslice_to_tensor, hf_model_weights_iterator, VocabParallelEmbedding)
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -46,7 +48,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
def __init__(self, config: GPT2Config): def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
@@ -57,17 +63,18 @@ class GPT2Attention(nn.Module):
self.head_dim = self.hidden_size // total_num_heads self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear( self.c_attn = QKVParallelLinear(
self.hidden_size, self.hidden_size,
3 * self.hidden_size, self.head_dim,
total_num_heads,
bias=True, bias=True,
gather_output=False, linear_method=linear_method,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
input_is_parallel=True, linear_method=linear_method,
) )
self.attn = PagedAttention(self.num_heads, self.attn = PagedAttention(self.num_heads,
self.head_dim, self.head_dim,
@@ -95,6 +102,7 @@ class GPT2MLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: GPT2Config, config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
@@ -102,15 +110,17 @@ class GPT2MLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=True, bias=True,
gather_output=False, linear_method=linear_method,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=True, bias=True,
input_is_parallel=True, linear_method=linear_method,
) )
self.act = get_act_fn(config.activation_function) quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states) hidden_states, _ = self.c_fc(hidden_states)
@@ -121,16 +131,20 @@ class GPT2MLP(nn.Module):
class GPT2Block(nn.Module): class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config): def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 * inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config) self.attn = GPT2Attention(config, linear_method)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config) self.mlp = GPT2MLP(inner_dim, config, linear_method)
def forward( def forward(
self, self,
@@ -160,24 +174,23 @@ class GPT2Block(nn.Module):
class GPT2Model(nn.Module): class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config): def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList( self.h = nn.ModuleList([
[GPT2Block(config) for _ in range(config.num_hidden_layers)]) GPT2Block(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
@@ -207,12 +220,15 @@ class GPT2Model(nn.Module):
class GPT2LMHeadModel(nn.Module): class GPT2LMHeadModel(nn.Module):
def __init__(self, config: GPT2Config): def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = GPT2Model(config) self.linear_method = linear_method
# TODO(zhuohan): create a new weight after implementing pipeline self.transformer = GPT2Model(config, linear_method)
# parallelism
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@@ -230,19 +246,12 @@ class GPT2LMHeadModel(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tensor_model_parallel_world_size = ( params_dict = dict(self.named_parameters(remove_duplicate=False))
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name: if "lm_head.weight" in name:
@@ -253,53 +262,19 @@ class GPT2LMHeadModel(nn.Module):
# Skip attention mask. # Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped. # NOTE: "c_attn.bias" should not be skipped.
continue continue
if not name.startswith("transformer."): if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
param = params_dict[name]
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
# The HF's GPT-2 implementation uses Conv1D instead of Linear. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights. # Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name: if conv1d_weight_name not in name:
continue continue
if not name.endswith(".weight"): if not name.endswith(".weight"):
continue continue
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
param = state_dict[name]
if name == "transformer.wte.weight": weight_loader = getattr(param, "weight_loader",
load_padded_tensor_parallel_vocab(param, loaded_weight, default_weight_loader)
tensor_model_parallel_rank) weight_loader(param, loaded_weight)
continue
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
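
Reviewer note: the one piece of model-specific logic that survives in the new GPT-2 `load_weights` is the transpose of `c_attn`, `c_proj` and `c_fc` weights, needed because the Hugging Face implementation stores them in Conv1D layout rather than Linear layout. A minimal sketch of that name-based rule on nested lists follows; `maybe_transpose` is a hypothetical helper name, and the real code calls `.t()` on a tensor.

```python
from typing import List

Matrix = List[List[float]]


def maybe_transpose(name: str, weight: Matrix) -> Matrix:
    """Transpose Conv1D-style weights; leave biases and other tensors alone."""
    for conv1d_weight_name in ("c_attn", "c_proj", "c_fc"):
        if conv1d_weight_name not in name:
            continue
        if not name.endswith(".weight"):
            continue
        # Conv1D layout is [in_features, out_features]; Linear expects the reverse.
        return [list(column) for column in zip(*weight)]
    return weight


conv1d_weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # shape [2, 3]
transposed = maybe_transpose("transformer.h.0.attn.c_attn.weight", conv1d_weight)
untouched = maybe_transpose("transformer.ln_f.weight", conv1d_weight)
```

Everything after the transpose (sharding the fused QKV along the head dimension, padding the vocabulary) is delegated to the parameter's `weight_loader` in the new code, which is why the manual `view(3, total_num_heads, ...)` slicing on the old side is deleted.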


@@ -31,15 +31,17 @@ from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
convert_pyslice_to_tensor, hf_model_weights_iterator, VocabParallelEmbedding)
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -47,7 +49,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTBigCodeAttention(nn.Module): class GPTBigCodeAttention(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
@@ -61,32 +67,26 @@ class GPTBigCodeAttention(nn.Module):
         self.multi_query = config.multi_query
         if self.multi_query:
+            total_num_kv_heads = 1
             self.num_kv_heads = 1
-            self.kv_dim = self.head_dim
-            self.c_attn_q = ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                bias=True,
-                gather_output=False,
-            )
-            self.c_attn_kv = nn.Linear(self.hidden_size,
-                                       2 * self.kv_dim,
-                                       bias=True)
         else:
+            total_num_kv_heads = total_num_heads
             self.num_kv_heads = self.num_heads
-            self.kv_dim = self.num_kv_heads * self.head_dim
-            self.c_attn = ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size + 2 * self.kv_dim,
-                bias=True,
-                gather_output=False,
-            )
+        self.kv_dim = self.head_dim * self.num_kv_heads
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            total_num_kv_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
-            input_is_parallel=True,
+            linear_method=linear_method,
         )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
@@ -100,17 +100,14 @@ class GPTBigCodeAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        if self.multi_query:
-            q, _ = self.c_attn_q(hidden_states)
-            kv = self.c_attn_kv(hidden_states)
-            k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
-        else:
-            qkv, _ = self.c_attn(hidden_states)
-            q, k, v = qkv.split([
-                self.hidden_size // self.tensor_model_parallel_world_size,
-                self.kv_dim, self.kv_dim
-            ],
-                                dim=-1)
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.split(
+            [
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim, self.kv_dim
+            ],
+            dim=-1,
+        )
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
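
Note on the hunk above: after this change a single split covers both the multi-query and the multi-head case, because only `kv_dim` differs between them. The sketch below computes the three split sizes per rank. It assumes `self.num_heads` is the per-rank head count (total heads divided by the tensor-parallel world size), which is defined outside the lines shown here. The helper name `qkv_split_sizes` is hypothetical.

```python
def qkv_split_sizes(hidden_size, total_num_heads, multi_query, world_size):
    """Per-rank [q, k, v] widths of the fused c_attn output (illustrative only)."""
    head_dim = hidden_size // total_num_heads
    num_heads = total_num_heads // world_size  # assumed per-rank head count
    # Multi-query attention keeps a single K/V head on every rank.
    num_kv_heads = 1 if multi_query else num_heads
    kv_dim = head_dim * num_kv_heads
    return [hidden_size // world_size, kv_dim, kv_dim]


mqa_sizes = qkv_split_sizes(6144, 48, multi_query=True, world_size=2)
mha_sizes = qkv_split_sizes(768, 12, multi_query=False, world_size=4)
```

With multi-query attention the K and V slices stay one head wide regardless of the world size, while the Q slice shrinks with it.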
@@ -124,6 +121,7 @@ class GPTBigMLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
@@ -131,15 +129,17 @@ class GPTBigMLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=True, bias=True,
gather_output=False, linear_method=linear_method,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=True, bias=True,
input_is_parallel=True, linear_method=linear_method,
) )
self.act = get_act_fn(config.activation_function) quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states) hidden_states, _ = self.c_fc(hidden_states)
@@ -150,16 +150,20 @@ class GPTBigMLP(nn.Module):
class GPTBigCodeBlock(nn.Module): class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 * inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config) self.attn = GPTBigCodeAttention(config, linear_method)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config) self.mlp = GPTBigMLP(inner_dim, config, linear_method)
def forward( def forward(
self, self,
@@ -189,23 +193,23 @@ class GPTBigCodeBlock(nn.Module):
class GPTBigCodeModel(nn.Module): class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList( self.h = nn.ModuleList([
[GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)]) GPTBigCodeBlock(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
@@ -235,12 +239,15 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module): class GPTBigCodeForCausalLM(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = GPTBigCodeModel(config) self.linear_method = linear_method
# TODO(zhuohan): create a new weight after implementing pipeline self.transformer = GPTBigCodeModel(config, linear_method)
# parallelism
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@@ -258,89 +265,21 @@ class GPTBigCodeForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens
-    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
-    _row_parallel_weights = ["c_proj.weight"]
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-        state_dict = self.state_dict()
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
-                # GPT-2 ties the weights of the embedding layer and the final
-                # linear layer.
                 continue
             if ".attn.bias" in name:
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
-            if not name.startswith("transformer."):
-                name = "transformer." + name
-            # For the fused QKV linear layer, manually shard the weights.
-            if "c_attn" in name:
-                # GPT-2's fused QKV has the shape of
-                # [3 * num_heads * head_size, hidden_size].
-                # When tensor parallelism is used, we shard the weights along
-                # the head dimension.
-                total_num_heads = self.config.num_attention_heads
-                total_num_kv_heads = (1 if self.config.multi_query else
-                                      total_num_heads)
-                hidden_size = self.config.hidden_size
-                head_size = hidden_size // total_num_heads
-                total_kv_size = head_size * total_num_kv_heads
-                num_heads = total_num_heads // tensor_model_parallel_world_size
-                head_start = tensor_model_parallel_rank * num_heads
-                head_end = (tensor_model_parallel_rank + 1) * num_heads
-                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-                wq, wk, wv = torch.split(
-                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
-                    dim=0)
-                wq = wq[head_size * head_start:head_size * head_end]
-                if not self.config.multi_query:
-                    # Split the heads when using normal multi-head attention
-                    wk = wk[head_size * head_start:head_size * head_end]
-                    wv = wv[head_size * head_start:head_size * head_end]
-                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
-                else:
-                    # For multi-query attention, we split the query
-                    # but replicate the key and value.
-                    loaded_weight_q = wq
-                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
-                    q_weight_name = name.replace("c_attn", "c_attn_q")
-                    kv_weight_name = name.replace("c_attn", "c_attn_kv")
-                    load_tensor_parallel_weights(state_dict[q_weight_name],
-                                                 loaded_weight_q,
-                                                 q_weight_name,
-                                                 self._column_parallel_weights,
-                                                 self._row_parallel_weights,
-                                                 tensor_model_parallel_rank)
-                    load_tensor_parallel_weights(state_dict[kv_weight_name],
-                                                 loaded_weight_kv,
-                                                 kv_weight_name,
-                                                 self._column_parallel_weights,
-                                                 self._row_parallel_weights,
-                                                 tensor_model_parallel_rank)
-                    continue
-            param = state_dict[name]
-            if name == "transformer.wte.weight":
-                load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tensor_model_parallel_rank)
-                continue
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights,
-                                         tensor_model_parallel_rank)
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
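
Note on the new `load_weights` above: the per-model sharding logic is replaced by a lookup of an optional `weight_loader` attribute on each parameter, with `default_weight_loader` as the fallback. The sketch below mimics that dispatch with plain Python lists instead of tensors. `FakeParam` and `make_sharded_loader` are hypothetical stand-ins, and this `default_weight_loader` is a simplified imitation written for the example, not vLLM's implementation.

```python
class FakeParam:
    """Stand-in for a parameter; holds a flat list instead of a tensor."""

    def __init__(self, size):
        self.data = [0.0] * size


def default_weight_loader(param, loaded_weight):
    # Fallback: copy the checkpoint values unchanged (simplified stand-in).
    assert len(param.data) == len(loaded_weight)
    param.data[:] = loaded_weight


def make_sharded_loader(rank, world_size):
    # A layer-specific loader that keeps only this rank's slice.
    def loader(param, loaded_weight):
        shard = len(loaded_weight) // world_size
        param.data[:] = loaded_weight[rank * shard:(rank + 1) * shard]

    return loader


def load_one(param, loaded_weight):
    # Same dispatch shape as in the diff: prefer the parameter's own loader.
    loader = getattr(param, "weight_loader", default_weight_loader)
    loader(param, loaded_weight)


plain = FakeParam(4)
load_one(plain, [1.0, 2.0, 3.0, 4.0])

sharded = FakeParam(2)
sharded.weight_loader = make_sharded_loader(rank=1, world_size=2)
load_one(sharded, [1.0, 2.0, 3.0, 4.0])
```

The model file no longer needs to know which weights are column- or row-parallel; that knowledge moves to whichever layer attaches the loader.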


@@ -29,14 +29,17 @@ from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator, from vllm.model_executor.layers.vocab_parallel_embedding import (
load_tensor_parallel_weights) VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -44,23 +47,28 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTJAttention(nn.Module): class GPTJAttention(nn.Module):
def __init__(self, config: GPTJConfig): def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = QKVParallelLinear(
config.hidden_size, config.hidden_size,
3 * config.hidden_size, self.head_size,
self.total_num_heads,
bias=False, bias=False,
gather_output=False, linear_method=linear_method,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
input_is_parallel=True, linear_method=linear_method,
) )
tp_world_size = get_tensor_model_parallel_world_size() tp_world_size = get_tensor_model_parallel_world_size()
@@ -102,20 +110,27 @@ class GPTJAttention(nn.Module):
class GPTJMLP(nn.Module): class GPTJMLP(nn.Module):
def __init__(self, intermediate_size: int, config: GPTJConfig): def __init__(
self,
intermediate_size: int,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.n_embd hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear( self.fc_in = ColumnParallelLinear(
hidden_size, hidden_size,
intermediate_size, intermediate_size,
gather_output=False, linear_method=linear_method,
) )
self.fc_out = RowParallelLinear( self.fc_out = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
input_is_parallel=True, linear_method=linear_method,
) )
self.act = get_act_fn(config.activation_function) quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc_in(hidden_states) hidden_states, _ = self.fc_in(hidden_states)
@@ -126,15 +141,19 @@ class GPTJMLP(nn.Module):
class GPTJBlock(nn.Module): class GPTJBlock(nn.Module):
def __init__(self, config: GPTJConfig): def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
if config.n_inner is None: if config.n_inner is None:
inner_dim = 4 * config.n_embd inner_dim = 4 * config.n_embd
else: else:
inner_dim = config.n_inner inner_dim = config.n_inner
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config) self.attn = GPTJAttention(config, linear_method)
self.mlp = GPTJMLP(inner_dim, config) self.mlp = GPTJMLP(inner_dim, config, linear_method)
def forward( def forward(
self, self,
@@ -160,7 +179,11 @@ class GPTJBlock(nn.Module):
class GPTJModel(nn.Module): class GPTJModel(nn.Module):
def __init__(self, config: GPTJConfig): def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.n_embd self.embed_dim = config.n_embd
@@ -169,7 +192,7 @@ class GPTJModel(nn.Module):
self.embed_dim, self.embed_dim,
) )
self.h = nn.ModuleList( self.h = nn.ModuleList(
[GPTJBlock(config) for _ in range(config.n_layer)]) [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
@@ -200,15 +223,20 @@ class GPTJModel(nn.Module):
class GPTJForCausalLM(nn.Module): class GPTJForCausalLM(nn.Module):
def __init__(self, config: GPTJConfig): def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method
assert not config.tie_word_embeddings assert not config.tie_word_embeddings
self.transformer = GPTJModel(config) self.transformer = GPTJModel(config, linear_method)
self.lm_head = ColumnParallelLinear( self.lm_head = ParallelLMHead(
config.n_embd,
config.vocab_size, config.vocab_size,
gather_output=False, config.n_embd,
bias=True,
) )
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@@ -226,43 +254,33 @@ class GPTJForCausalLM(nn.Module):
                                    input_metadata, self.lm_head.bias)
         return next_tokens
-    _column_parallel_weights = [
-        "wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
-        "lm_head.bias"
-    ]
-    _row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        tp_rank = get_tensor_model_parallel_rank()
-        state_dict = self.state_dict()
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
-            is_attention_weight = False
-            for stride_id, att_weight_name in enumerate(
-                ["q_proj", "k_proj", "v_proj"]):
-                if att_weight_name not in name:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
                     continue
-                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
-                shard_size = param.shape[1]
-                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
-                                              (tp_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
-                assert param_slice.shape == loaded_weight.shape
-                param_slice.copy_(loaded_weight)
-                is_attention_weight = True
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
                 break
-            if is_attention_weight:
-                continue
-            param = state_dict[name]
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights, tp_rank)
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
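
Note on the new GPT-J `load_weights` above: the `is_attention_weight` flag is replaced by a `for ... else`, where the `else` branch runs only when no mapping entry matched and the loop finished without `break`. The sketch below isolates that name resolution. The mapping tuples are copied from the diff; the function name `resolve_checkpoint_name` and the sample checkpoint names are assumptions made for the example.

```python
STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def resolve_checkpoint_name(name):
    """Map a checkpoint weight name to (model parameter name, shard_id)."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name not in name:
            continue
        target = (name.replace(weight_name, param_name), shard_id)
        break
    else:
        # No stacked parameter matched: load the weight under its own name.
        target = (name, None)
    return target


stacked = resolve_checkpoint_name("transformer.h.0.attn.k_proj.weight")
plain = resolve_checkpoint_name("transformer.h.0.mlp.fc_in.weight")
```

The first entry whose shard name occurs in the checkpoint name wins, so the order of the mapping matters if one shard name is a substring of another.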


@@ -29,14 +29,17 @@ from transformers import GPTNeoXConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator, from vllm.model_executor.layers.vocab_parallel_embedding import (
load_tensor_parallel_weights) VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -44,7 +47,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
def __init__(self, config: GPTNeoXConfig): def __init__(
self,
config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@@ -56,15 +63,16 @@ class GPTNeoXAttention(nn.Module):
self.num_heads = (self.total_num_heads // self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size) tensor_model_parallel_world_size)
self.query_key_value = ColumnParallelLinear( self.query_key_value = QKVParallelLinear(
config.hidden_size, config.hidden_size,
3 * config.hidden_size, self.head_size,
gather_output=False, self.total_num_heads,
linear_method=linear_method,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
input_is_parallel=True, linear_method=linear_method,
) )
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
@@ -100,19 +108,25 @@ class GPTNeoXAttention(nn.Module):
class GPTNeoXMLP(nn.Module): class GPTNeoXMLP(nn.Module):
def __init__(self, config: GPTNeoXConfig): def __init__(
self,
config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.dense_h_to_4h = ColumnParallelLinear( self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
gather_output=False, linear_method=linear_method,
) )
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
input_is_parallel=True, linear_method=linear_method,
) )
self.act = get_act_fn(config.hidden_act) quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states, _ = self.dense_h_to_4h(hidden_states) hidden_states, _ = self.dense_h_to_4h(hidden_states)
@@ -123,15 +137,19 @@ class GPTNeoXMLP(nn.Module):
class GPTNeoXLayer(nn.Module): class GPTNeoXLayer(nn.Module):
def __init__(self, config: GPTNeoXConfig): def __init__(
self,
config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.use_parallel_residual = config.use_parallel_residual self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config) self.attention = GPTNeoXAttention(config, linear_method)
self.mlp = GPTNeoXMLP(config) self.mlp = GPTNeoXMLP(config, linear_method)
def forward( def forward(
self, self,
@@ -169,7 +187,11 @@ class GPTNeoXLayer(nn.Module):
class GPTNeoXModel(nn.Module): class GPTNeoXModel(nn.Module):
def __init__(self, config: GPTNeoXConfig): def __init__(
self,
config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
@@ -177,8 +199,10 @@ class GPTNeoXModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList([
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) GPTNeoXLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
@@ -210,15 +234,18 @@ class GPTNeoXModel(nn.Module):
class GPTNeoXForCausalLM(nn.Module): class GPTNeoXForCausalLM(nn.Module):
def __init__(self, config): def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.gpt_neox = GPTNeoXModel(config) self.linear_method = linear_method
self.embed_out = ColumnParallelLinear( self.gpt_neox = GPTNeoXModel(config, linear_method)
config.hidden_size, self.embed_out = ParallelLMHead(
config.vocab_size, config.vocab_size,
bias=False, config.hidden_size,
gather_output=False,
) )
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@@ -236,50 +263,35 @@ class GPTNeoXForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens
-    _column_parallel_weights = [
-        "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
-        "dense_h_to_4h.bias"
-    ]
-    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-        state_dict = self.state_dict()
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue
-            param = state_dict[name]
-            if "query_key_value" in name:
-                # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
-                # [num_heads * 3 * head_size, hidden_size], while the
-                # required shape is [3 * num_heads * head_size, hidden_size].
-                # Thus, we need weight conversion.
-                shard_size = param.shape[0]
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+            param = params_dict[name]
+            if "query_key_value" in name:
+                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
                 num_heads = self.config.num_attention_heads
-                hidden_size = self.config.hidden_size
-                head_size = hidden_size // num_heads
-                if "query_key_value.weight" in name:
-                    loaded_weight = loaded_weight.view(-1, 3, head_size,
-                                                       hidden_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif "query_key_value.bias" in name:
-                    loaded_weight = loaded_weight.view(-1, 3, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-                else:
-                    raise ValueError(f"Unexpected weight name: {name}")
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
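
Note on the GPT-NeoX conversion above: both the old and the new code reorder the fused QKV output dimension from (num_heads, 3, head_size) to (3, num_heads, head_size) with a view, a transpose and a reshape. The sketch below expresses the same reordering as a list of source row indices in plain Python, for the common case where the output dimension is dimension 0. The function name `neox_qkv_source_rows` is hypothetical.

```python
def neox_qkv_source_rows(num_heads, head_size):
    """For each row of the converted [3, heads, head_size] layout, return the
    row it comes from in the checkpoint's [heads, 3, head_size] layout."""
    order = []
    for part in range(3):  # 0 = Q, 1 = K, 2 = V
        for head in range(num_heads):
            # Offset of (head, part) in the checkpoint's interleaved layout.
            base = (head * 3 + part) * head_size
            order.extend(range(base, base + head_size))
    return order


# Example: 2 heads with head_size 1.
order = neox_qkv_source_rows(2, 1)
```

Applying this index list to the rows of the checkpoint weight gathers all Q rows first, then all K rows, then all V rows, which is the layout the fused linear layer expects.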


@@ -9,15 +9,17 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, from vllm.model_executor.weight_utils import (default_weight_loader,
RowParallelLinear, hf_model_weights_iterator)
VocabParallelEmbedding)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -30,20 +32,17 @@ class InternLMMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, hidden_size, [intermediate_size] * 2,
2 * intermediate_size,
bias=False, bias=False,
gather_output=False, linear_method=linear_method)
) self.down_proj = RowParallelLinear(intermediate_size,
self.down_proj = RowParallelLinear( hidden_size,
intermediate_size, bias=False,
hidden_size, linear_method=linear_method)
bias=False,
input_is_parallel=True,
)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
@@ -62,8 +61,10 @@ class InternLMAttention(nn.Module):
self, self,
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
bias: bool,
rope_theta: float = 10000, rope_theta: float = 10000,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -78,17 +79,18 @@ class InternLMAttention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
3 * self.total_num_heads * self.head_dim, self.head_dim,
bias=True, self.total_num_heads,
gather_output=False, bias=bias,
linear_method=linear_method,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=True, bias=bias,
input_is_parallel=True, linear_method=linear_method,
) )
self.attn = PagedAttentionWithRoPE( self.attn = PagedAttentionWithRoPE(
self.num_heads, self.num_heads,
@@ -117,7 +119,11 @@ class InternLMAttention(nn.Module):
class InternLMDecoderLayer(nn.Module): class InternLMDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
@@ -126,13 +132,16 @@ class InternLMDecoderLayer(nn.Module):
self.self_attn = InternLMAttention( self.self_attn = InternLMAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
bias=config.bias,
rope_theta=rope_theta, rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
) )
self.mlp = InternLMMLP( self.mlp = InternLMMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@@ -146,10 +155,15 @@ class InternLMDecoderLayer(nn.Module):
kv_cache: KVCache, kv_cache: KVCache,
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states if residual is None:
hidden_states = self.input_layernorm(hidden_states) residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -157,19 +171,21 @@ class InternLMDecoderLayer(nn.Module):
input_metadata=input_metadata, input_metadata=input_metadata,
cache_event=cache_event, cache_event=cache_event,
) )
hidden_states = residual + hidden_states
# Fully Connected # Fully Connected
residual = hidden_states hidden_states, residual = self.post_attention_layernorm(
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states return hidden_states, residual
return hidden_states
class InternLMModel(nn.Module): class InternLMModel(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
@@ -181,7 +197,7 @@ class InternLMModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
InternLMDecoderLayer(config) InternLMDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -195,36 +211,37 @@ class InternLMModel(nn.Module):
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
if cache_events is None: if cache_events is None:
cache_event = None cache_event = None
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, input_metadata,
cache_event, cache_event,
residual,
) )
hidden_states = self.norm(hidden_states) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class InternLMForCausalLM(nn.Module): class InternLMForCausalLM(nn.Module):
def __init__(self, config): def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = InternLMModel(config) self.linear_method = linear_method
vocab_size = ((config.vocab_size + 63) // 64) * 64 self.model = InternLMModel(config, linear_method)
self.lm_head = ColumnParallelLinear( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
@@ -241,69 +258,33 @@ class InternLMForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = [
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() stacked_params_mapping = [
state_dict = self.state_dict() # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "gate_up_proj")] param = params_dict[name.replace(weight_name, param_name)]
shard_size = param.shape[0] // 2 weight_loader = param.weight_loader
loaded_weight = loaded_weight[ weight_loader(param, loaded_weight, shard_id)
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break break
if is_gate_up_weight: else:
continue param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
param = state_dict[name] default_weight_loader)
load_tensor_parallel_weights(param, loaded_weight, name, weight_loader(param, loaded_weight)
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
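
Note on the new `load_weights` above: each checkpoint tensor is routed through `stacked_params_mapping`, and the `for ... else` falls through to a plain copy when no shard name matches. The sketch below mimics that routing with plain Python lists. `FusedParam`, its `slices` table, and the shortened parameter names are hypothetical stand-ins for illustration only; they are not vLLM classes, and the real loaders also handle tensor-parallel sharding, which is omitted here.

```python
from typing import Dict, Hashable, List, Union

# Same (param_name, shard_name, shard_id) triples as in the diff.
STACKED_PARAMS_MAPPING = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


class FusedParam:
    """Hypothetical fused parameter: one flat buffer, one slice per shard."""

    def __init__(self, shard_sizes: Dict[Hashable, int]) -> None:
        self.slices: Dict[Hashable, slice] = {}
        offset = 0
        for shard_id, size in shard_sizes.items():
            self.slices[shard_id] = slice(offset, offset + size)
            offset += size
        self.data: List[float] = [0.0] * offset

    def weight_loader(self, loaded: List[float], shard_id: Hashable) -> None:
        # Copy one checkpoint tensor into its slot of the fused buffer.
        target = self.slices[shard_id]
        assert target.stop - target.start == len(loaded)
        self.data[target] = loaded


Params = Dict[str, Union[FusedParam, List[float]]]


def load_weights(params: Params, checkpoint: Dict[str, List[float]]) -> None:
    for name, loaded in checkpoint.items():
        for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
            if weight_name not in name:
                continue
            fused = params[name.replace(weight_name, param_name)]
            fused.weight_loader(loaded, shard_id)
            break
        else:
            # No shard name matched: copy the tensor unchanged.
            params[name][:] = loaded


params: Params = {
    "attn.qkv_proj.weight": FusedParam({"q": 2, "k": 1, "v": 1}),
    "mlp.gate_up_proj.weight": FusedParam({0: 2, 1: 2}),
    "norm.weight": [0.0, 0.0],
}
checkpoint = {
    "attn.q_proj.weight": [1.0, 2.0],
    "attn.k_proj.weight": [3.0],
    "attn.v_proj.weight": [4.0],
    "mlp.gate_proj.weight": [5.0, 6.0],
    "mlp.up_proj.weight": [7.0, 8.0],
    "norm.weight": [9.0, 10.0],
}
load_weights(params, checkpoint)
print(params["attn.qkv_proj.weight"].data)
print(params["mlp.gate_up_proj.weight"].data)
```

Keeping the shard layout inside the parameter's loader is what lets the per-model `load_weights` shrink to a name lookup, which matches the removal of the hand-written slicing code in this file.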


@@ -33,17 +33,19 @@ from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.model_executor.quantization_utils import QuantizationConfig hf_model_weights_iterator)
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -56,19 +58,17 @@ class LlamaMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = ParallelLinear.column(hidden_size, self.gate_up_proj = MergedColumnParallelLinear(
2 * intermediate_size, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
gather_output=False, linear_method=linear_method)
quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size,
self.down_proj = ParallelLinear.row(intermediate_size, hidden_size,
hidden_size, bias=False,
bias=False, linear_method=linear_method)
input_is_parallel=True,
quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
@@ -91,7 +91,7 @@ class LlamaAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -109,7 +109,6 @@ class LlamaAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs. # the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0 assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
@@ -117,21 +116,19 @@ class LlamaAttention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ParallelLinear.column( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
(self.total_num_heads +
2 * self.total_num_kv_heads * num_kv_heads_replicas) *
self.head_dim, self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False, bias=False,
gather_output=False, linear_method=linear_method,
quant_config=quant_config,
) )
self.o_proj = ParallelLinear.row( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
input_is_parallel=True, linear_method=linear_method,
quant_config=quant_config,
) )
self.attn = PagedAttentionWithRoPE( self.attn = PagedAttentionWithRoPE(
self.num_heads, self.num_heads,
@@ -165,11 +162,10 @@ class LlamaDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", max_position_embeddings = getattr(config, "max_position_embeddings",
@@ -181,13 +177,13 @@ class LlamaDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
quant_config=quant_config, linear_method=linear_method,
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, linear_method=linear_method,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@@ -201,10 +197,15 @@ class LlamaDecoderLayer(nn.Module):
kv_cache: KVCache, kv_cache: KVCache,
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states if residual is None:
hidden_states = self.input_layernorm(hidden_states) residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -212,14 +213,12 @@ class LlamaDecoderLayer(nn.Module):
input_metadata=input_metadata, input_metadata=input_metadata,
cache_event=cache_event, cache_event=cache_event,
) )
hidden_states = residual + hidden_states
# Fully Connected # Fully Connected
residual = hidden_states hidden_states, residual = self.post_attention_layernorm(
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states return hidden_states, residual
return hidden_states
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
@@ -227,20 +226,18 @@ class LlamaModel(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer(config, quant_config) LlamaDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -254,20 +251,22 @@ class LlamaModel(nn.Module):
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
if cache_events is None: if cache_events is None:
cache_event = None cache_event = None
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, input_metadata,
cache_event, cache_event,
residual,
) )
hidden_states = self.norm(hidden_states) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
@@ -276,19 +275,13 @@ class LlamaForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.linear_method = linear_method
self.model = LlamaModel(config, quant_config) self.model = LlamaModel(config, linear_method)
vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
# NOTE: The LM head is not quantized.
self.lm_head = ParallelLinear.column(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
@@ -305,118 +298,33 @@ class LlamaForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_layers = []
_row_parallel_layers = ["o_proj", "down_proj"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
if self.quant_config is None: stacked_params_mapping = [
weight_suffixes = ["weight"] # (param_name, shard_name, shard_id)
else: ("qkv_proj", "q_proj", "q"),
weight_suffixes = self.quant_config.get_tp_tensor_names() ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
column_parallel_weights: List[str] = [] ("gate_up_proj", "gate_proj", 0),
for layer in self._column_parallel_layers: ("gate_up_proj", "up_proj", 1),
for suffix in weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
num_kv_heads_replicas = max(1,
tp_size // self.config.num_key_value_heads)
num_kv_heads_per_gpu = max(1,
self.config.num_key_value_heads // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
num_kv_heads_per_gpu)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
] ]
state_dict = self.state_dict() params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
is_packed = False
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight = loaded_weight.T
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "qkv_proj")] param = params_dict[name.replace(weight_name, param_name)]
if is_transposed: weight_loader = param.weight_loader
param = param.T weight_loader(param, loaded_weight, shard_id)
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
if weight_name in ["k_proj", "v_proj"]:
shard_id = tp_rank // num_kv_heads_replicas
else:
shard_id = tp_rank
loaded_weight = loaded_weight[shard_size *
shard_id:shard_size *
(shard_id + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break break
if is_attention_weight: else:
continue param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
is_gate_up_weight = False default_weight_loader)
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): weight_loader(param, loaded_weight)
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
if is_transposed:
param = param.T
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if is_transposed:
param = param.T
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tp_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
column_parallel_weights,
row_parallel_weights, tp_rank)
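
Note on the decoder-layer hunks above: the explicit `hidden_states = residual + hidden_states` lines are removed and a `residual` value is threaded through the layers instead, with the norm called as `norm(hidden_states, residual)`. The sketch below checks, on plain Python lists, that the two data flows give the same result. It assumes the two-argument norm call first adds the residual and then normalizes the sum, returning both; `rms_norm`, `add_rms_norm`, and the toy `attn`/`mlp` functions are hypothetical stand-ins, not the vLLM kernels.

```python
import math
from typing import Callable, List, Optional, Tuple

Vec = List[float]
Fn = Callable[[Vec], Vec]


def rms_norm(x: Vec, weight: Vec, eps: float = 1e-6) -> Vec:
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def add_rms_norm(x: Vec, residual: Vec, weight: Vec) -> Tuple[Vec, Vec]:
    # Assumed semantics of norm(x, residual): add first, then normalize.
    new_residual = [a + b for a, b in zip(x, residual)]
    return rms_norm(new_residual, weight), new_residual


def old_layer(h: Vec, w: Vec, attn: Fn, mlp: Fn) -> Vec:
    residual = h
    h = attn(rms_norm(h, w))
    h = [a + b for a, b in zip(residual, h)]
    residual = h
    h = mlp(rms_norm(h, w))
    return [a + b for a, b in zip(residual, h)]


def new_layer(h: Vec, residual: Optional[Vec], w: Vec, attn: Fn,
              mlp: Fn) -> Tuple[Vec, Vec]:
    if residual is None:
        residual = h
        h = rms_norm(h, w)
    else:
        h, residual = add_rms_norm(h, residual, w)
    h = attn(h)
    h, residual = add_rms_norm(h, residual, w)
    return mlp(h), residual


def run_old(h: Vec, w: Vec, attn: Fn, mlp: Fn, num_layers: int) -> Vec:
    for _ in range(num_layers):
        h = old_layer(h, w, attn, mlp)
    return rms_norm(h, w)


def run_new(h: Vec, w: Vec, attn: Fn, mlp: Fn, num_layers: int) -> Vec:
    residual = None
    for _ in range(num_layers):
        h, residual = new_layer(h, residual, w, attn, mlp)
    h, _ = add_rms_norm(h, residual, w)
    return h


def toy_attn(v: Vec) -> Vec:
    return [2.0 * x + 1.0 for x in v]


def toy_mlp(v: Vec) -> Vec:
    return [x * x for x in v]


weight = [1.0, 0.5, 2.0]
inputs = [0.3, -1.2, 2.5]
old_out = run_old(inputs, weight, toy_attn, toy_mlp, num_layers=2)
new_out = run_new(inputs, weight, toy_attn, toy_mlp, num_layers=2)
print(all(math.isclose(a, b) for a, b in zip(old_out, new_out)))
```

Deferring each residual add to the next norm call is what allows the add and the normalization to be served by a single fused operation, which is the point of the "fused add rmsnorm" commit listed at the top of this comparison.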


@@ -33,17 +33,19 @@ from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.model_executor.quantization_utils import QuantizationConfig hf_model_weights_iterator)
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -56,19 +58,17 @@ class MistralMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = ParallelLinear.column(hidden_size, self.gate_up_proj = MergedColumnParallelLinear(
2 * intermediate_size, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
gather_output=False, linear_method=linear_method)
quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size,
self.down_proj = ParallelLinear.row(intermediate_size, hidden_size,
hidden_size, bias=False,
bias=False, linear_method=linear_method)
input_is_parallel=True,
quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
@@ -89,7 +89,7 @@ class MistralAttention(nn.Module):
num_kv_heads: int, num_kv_heads: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None: sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -98,8 +98,15 @@ class MistralAttention(nn.Module):
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0 if self.total_num_kv_heads >= tp_size:
self.num_kv_heads = self.total_num_kv_heads // tp_size # Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
@@ -107,20 +114,19 @@ class MistralAttention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.qkv_proj = ParallelLinear.column( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim, self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False, bias=False,
gather_output=False, linear_method=linear_method,
quant_config=quant_config,
) )
self.o_proj = ParallelLinear.row( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
input_is_parallel=True, linear_method=linear_method,
quant_config=quant_config,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads, self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim, self.head_dim,
@@ -153,7 +159,7 @@ class MistralDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: MistralConfig, config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@@ -165,13 +171,13 @@ class MistralDecoderLayer(nn.Module):
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
quant_config=quant_config, linear_method=linear_method,
sliding_window=config.sliding_window) sliding_window=config.sliding_window)
self.mlp = MistralMLP( self.mlp = MistralMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, linear_method=linear_method,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@@ -185,10 +191,15 @@ class MistralDecoderLayer(nn.Module):
kv_cache: KVCache, kv_cache: KVCache,
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states if residual is None:
hidden_states = self.input_layernorm(hidden_states) residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -196,14 +207,12 @@ class MistralDecoderLayer(nn.Module):
input_metadata=input_metadata, input_metadata=input_metadata,
cache_event=cache_event, cache_event=cache_event,
) )
hidden_states = residual + hidden_states
# Fully Connected # Fully Connected
residual = hidden_states hidden_states, residual = self.post_attention_layernorm(
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states return hidden_states, residual
return hidden_states
class MistralModel(nn.Module): class MistralModel(nn.Module):
@@ -211,20 +220,19 @@ class MistralModel(nn.Module):
def __init__( def __init__(
self, self,
config: MistralConfig, config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MistralDecoderLayer(config, quant_config) MistralDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -238,20 +246,22 @@ class MistralModel(nn.Module):
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
if cache_events is None: if cache_events is None:
cache_event = None cache_event = None
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, input_metadata,
cache_event, cache_event,
residual,
) )
hidden_states = self.norm(hidden_states) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
@@ -260,19 +270,13 @@ class MistralForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: MistralConfig, config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.linear_method = linear_method
self.model = MistralModel(config, quant_config) self.model = MistralModel(config, linear_method)
vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
# NOTE: The LM head is not quantized.
self.lm_head = ParallelLinear.column(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
@@ -289,112 +293,33 @@ class MistralForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_layers = []
_row_parallel_layers = ["o_proj", "down_proj"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
if self.quant_config is None: stacked_params_mapping = [
weight_suffixes = ["weight"] # (param_name, shard_name, shard_id)
else: ("qkv_proj", "q_proj", "q"),
weight_suffixes = self.quant_config.get_tp_tensor_names() ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
column_parallel_weights: List[str] = [] ("gate_up_proj", "gate_proj", 0),
for layer in self._column_parallel_layers: ("gate_up_proj", "up_proj", 1),
for suffix in weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_key_value_heads // tp_size)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
] ]
state_dict = self.state_dict() params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
is_packed = False
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight = loaded_weight.T
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "qkv_proj")] param = params_dict[name.replace(weight_name, param_name)]
if is_transposed: weight_loader = param.weight_loader
param = param.T weight_loader(param, loaded_weight, shard_id)
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break break
if is_attention_weight: else:
continue param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
is_gate_up_weight = False default_weight_loader)
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): weight_loader(param, loaded_weight)
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
if is_transposed:
param = param.T
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if is_transposed:
param = param.T
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
column_parallel_weights,
row_parallel_weights,
tensor_model_parallel_rank)
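
Review note: the decoder-layer hunks above stop adding the residual explicitly and instead pass `residual` into the next norm call, which returns both the normalized activations and the updated residual. The sketch below is a minimal pure-Python illustration of why the two flows should agree. It assumes unit norm weights, and `rms_norm`, `add_rms_norm`, and `sublayer` are hypothetical stand-ins rather than the vLLM kernels.

```python
import math
from typing import List, Optional, Tuple

Vector = List[float]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """Plain RMSNorm with unit weight (the learned scale is omitted)."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def add_rms_norm(x: Vector, residual: Vector,
                 eps: float = 1e-6) -> Tuple[Vector, Vector]:
    """Hypothetical two-argument norm: add first, then normalize.

    Returns (normalized, new_residual).
    """
    summed = [a + b for a, b in zip(x, residual)]
    return rms_norm(summed, eps), summed


def sublayer(x: Vector) -> Vector:
    """Toy stand-in for attention or the MLP."""
    return [0.5 * v + 1.0 for v in x]


def old_layer(h: Vector) -> Vector:
    # Old flow: each residual add is a separate step after the sub-layer.
    residual = h
    h = sublayer(rms_norm(h))
    h = [a + b for a, b in zip(residual, h)]
    residual = h
    h = sublayer(rms_norm(h))
    return [a + b for a, b in zip(residual, h)]


def new_layer(h: Vector,
              residual: Optional[Vector]) -> Tuple[Vector, Vector]:
    # New flow: the pending add is deferred into the next norm call.
    if residual is None:
        residual = h
        h = rms_norm(h)
    else:
        h, residual = add_rms_norm(h, residual)
    h = sublayer(h)
    h, residual = add_rms_norm(h, residual)
    h = sublayer(h)
    return h, residual


def run_old(h: Vector, num_layers: int) -> Vector:
    for _ in range(num_layers):
        h = old_layer(h)
    return rms_norm(h)


def run_new(h: Vector, num_layers: int) -> Vector:
    residual = None
    for _ in range(num_layers):
        h, residual = new_layer(h, residual)
    # The final norm absorbs the last pending residual add.
    h, _ = add_rms_norm(h, residual)
    return h
```

Under these assumptions `run_old(x, n)` and `run_new(x, n)` agree for any input, since the new flow only moves each add into the following norm. The last layer's add is picked up by the model-level `self.norm(hidden_states, residual)` call.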


@@ -9,15 +9,17 @@ import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, from vllm.model_executor.layers.vocab_parallel_embedding import (
hf_model_weights_iterator, VocabParallelEmbedding)
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
@@ -39,7 +41,11 @@ def _get_alibi_slopes(
class MPTAttention(nn.Module): class MPTAttention(nn.Module):
def __init__(self, config: MPTConfig): def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
self.total_num_heads = config.n_heads self.total_num_heads = config.n_heads
@@ -49,11 +55,13 @@ class MPTAttention(nn.Module):
assert not config.attn_config["prefix_lm"] assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"] assert config.attn_config["alibi"]
self.qkv_proj = ColumnParallelLinear( # pylint: disable=invalid-name
self.Wqkv = QKVParallelLinear(
self.d_model, self.d_model,
3 * self.d_model, self.d_model // self.total_num_heads,
self.total_num_heads,
bias=not config.no_bias, bias=not config.no_bias,
gather_output=False, linear_method=linear_method,
) )
if self.qk_ln: if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model) self.q_ln = nn.LayerNorm(self.d_model)
@@ -62,7 +70,7 @@ class MPTAttention(nn.Module):
self.d_model, self.d_model,
self.d_model, self.d_model,
bias=not config.no_bias, bias=not config.no_bias,
input_is_parallel=True, linear_method=linear_method,
) )
tp_world_size = get_tensor_model_parallel_world_size() tp_world_size = get_tensor_model_parallel_world_size()
@@ -91,7 +99,7 @@ class MPTAttention(nn.Module):
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: ) -> torch.Tensor:
del position_ids # unused. del position_ids # unused.
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None: if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
@@ -107,7 +115,11 @@ class MPTAttention(nn.Module):
class MPTMLP(nn.Module): class MPTMLP(nn.Module):
def __init__(self, config: MPTConfig): def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.d_model hidden_size = config.d_model
expansion_ratio = config.expansion_ratio expansion_ratio = config.expansion_ratio
@@ -116,14 +128,15 @@ class MPTMLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=not config.no_bias, bias=not config.no_bias,
gather_output=False, linear_method=linear_method,
) )
self.act = get_act_fn("gelu") quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, intermediate_size)
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=not config.no_bias, bias=not config.no_bias,
input_is_parallel=True, linear_method=linear_method,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,13 +148,17 @@ class MPTMLP(nn.Module):
class MPTBlock(nn.Module): class MPTBlock(nn.Module):
def __init__(self, config: MPTConfig): def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.d_model hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size) self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config) self.attn = MPTAttention(config, linear_method)
self.norm_2 = nn.LayerNorm(hidden_size) self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config) self.ffn = MPTMLP(config, linear_method)
def forward( def forward(
self, self,
@@ -168,7 +185,11 @@ class MPTBlock(nn.Module):
class MPTModel(nn.Module): class MPTModel(nn.Module):
def __init__(self, config: MPTConfig): def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
assert config.embedding_fraction == 1.0 assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm" assert config.norm_type == "low_precision_layernorm"
@@ -178,7 +199,7 @@ class MPTModel(nn.Module):
config.d_model, config.d_model,
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[MPTBlock(config) for _ in range(config.n_layers)]) [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model) self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias: if config.no_bias:
for module in self.modules(): for module in self.modules():
@@ -215,14 +236,17 @@ class MPTModel(nn.Module):
class MPTForCausalLM(nn.Module): class MPTForCausalLM(nn.Module):
def __init__(self, config: MPTConfig): def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.linear_method = linear_method
self.transformer = MPTModel(config) self.transformer = MPTModel(config, linear_method)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@@ -240,45 +264,15 @@ class MPTForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
_row_parallel_weights = ["out_proj.weight", "down_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tp_world_size = get_tensor_model_parallel_world_size() params_dict = dict(self.named_parameters(remove_duplicate=False))
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "Wqkv" in name: param = params_dict[name]
# NOTE(woosuk): MPT's fused QKV has the shape of weight_loader = getattr(param, "weight_loader",
# [3 * num_heads * head_size, hidden_size]. default_weight_loader)
# When tensor model parallelism is used, we need to shard weight_loader(param, loaded_weight)
# the weight along the hidden dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
name = name.replace("Wqkv", "qkv_proj")
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
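
Review note: the removed `Wqkv` branch sharded MPT's fused QKV weight by attention head, as its comment describes. Which code performs that slicing after this change is not visible in this diff; presumably it moved into the layer's own loader. The sketch below only illustrates the indexing the removed lines performed: `qkv_rows_for_rank` is a hypothetical helper, and it returns row indices instead of slicing a real tensor.

```python
from typing import List


def qkv_rows_for_rank(total_num_heads: int, head_size: int,
                      tp_size: int, tp_rank: int) -> List[int]:
    """Rows of a fused [3 * total_num_heads * head_size, hidden] weight
    that belong to one tensor-parallel rank."""
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    head_start = tp_rank * num_heads
    head_end = head_start + num_heads
    rows: List[int] = []
    for part in range(3):  # 0 = Q, 1 = K, 2 = V
        base = part * total_num_heads * head_size
        rows.extend(range(base + head_start * head_size,
                          base + head_end * head_size))
    return rows


# Example: 4 heads of size 2 split across 2 ranks.
rank0_rows = qkv_rows_for_rank(4, 2, tp_size=2, tp_rank=0)
rank1_rows = qkv_rows_for_rank(4, 2, tp_size=2, tp_rank=1)
```

Each rank takes the same head range from the Q, K, and V blocks, so its rows are three separate contiguous runs and not one contiguous slice of the fused matrix.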


@@ -30,14 +30,18 @@ from transformers import OPTConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator, from vllm.model_executor.layers.vocab_parallel_embedding import (
load_tensor_parallel_weights) VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -63,6 +67,7 @@ class OPTAttention(nn.Module):
embed_dim: int, embed_dim: int,
num_heads: int, num_heads: int,
bias: bool = True, bias: bool = True,
linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
@@ -74,17 +79,18 @@ class OPTAttention(nn.Module):
self.head_dim = embed_dim // total_num_heads self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = QKVParallelLinear(
embed_dim, embed_dim,
3 * embed_dim, self.head_dim,
total_num_heads,
bias=bias, bias=bias,
gather_output=False, linear_method=linear_method,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
embed_dim, embed_dim,
embed_dim, embed_dim,
bias=bias, bias=bias,
input_is_parallel=True, linear_method=linear_method,
) )
self.attn = PagedAttention(self.num_heads, self.attn = PagedAttention(self.num_heads,
self.head_dim, self.head_dim,
@@ -108,7 +114,11 @@ class OPTAttention(nn.Module):
class OPTDecoderLayer(nn.Module): class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig): def __init__(
self,
config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@@ -116,9 +126,12 @@ class OPTDecoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
bias=config.enable_bias, bias=config.enable_bias,
linear_method=linear_method,
) )
self.do_layer_norm_before = config.do_layer_norm_before self.do_layer_norm_before = config.do_layer_norm_before
self.activation_fn = get_act_fn(config.activation_function) quant_config = getattr(linear_method, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.self_attn_layer_norm = nn.LayerNorm( self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, self.embed_dim,
@@ -127,13 +140,13 @@ class OPTDecoderLayer(nn.Module):
self.embed_dim, self.embed_dim,
config.ffn_dim, config.ffn_dim,
bias=config.enable_bias, bias=config.enable_bias,
gather_output=False, linear_method=linear_method,
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.ffn_dim, config.ffn_dim,
self.embed_dim, self.embed_dim,
bias=config.enable_bias, bias=config.enable_bias,
input_is_parallel=True, linear_method=linear_method,
) )
self.final_layer_norm = nn.LayerNorm( self.final_layer_norm = nn.LayerNorm(
self.embed_dim, self.embed_dim,
@@ -177,7 +190,11 @@ class OPTDecoderLayer(nn.Module):
class OPTDecoder(nn.Module): class OPTDecoder(nn.Module):
def __init__(self, config: OPTConfig): def __init__(
self,
config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
@@ -194,16 +211,18 @@ class OPTDecoder(nn.Module):
# Project out & in will be replicated if they exist. # Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size, self.project_out = ReplicatedLinear(config.hidden_size,
config.word_embed_proj_dim, config.word_embed_proj_dim,
bias=False) bias=False,
linear_method=linear_method)
else: else:
self.project_out = None self.project_out = None
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim, self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
config.hidden_size, config.hidden_size,
bias=False) bias=False,
linear_method=linear_method)
else: else:
self.project_in = None self.project_in = None
@@ -218,8 +237,10 @@ class OPTDecoder(nn.Module):
else: else:
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList( self.layers = nn.ModuleList([
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) OPTDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
def forward( def forward(
self, self,
@@ -232,7 +253,7 @@ class OPTDecoder(nn.Module):
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
pos_embeds = self.embed_positions(positions) pos_embeds = self.embed_positions(positions)
if self.project_in is not None: if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds) inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds hidden_states = inputs_embeds + pos_embeds
for i in range(len(self.layers)): for i in range(len(self.layers)):
@@ -247,15 +268,19 @@ class OPTDecoder(nn.Module):
if self.final_layer_norm is not None: if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None: if self.project_out is not None:
hidden_states = self.project_out(hidden_states) hidden_states, _ = self.project_out(hidden_states)
return hidden_states return hidden_states
class OPTModel(nn.Module): class OPTModel(nn.Module):
def __init__(self, config: OPTConfig): def __init__(
self,
config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.decoder = OPTDecoder(config) self.decoder = OPTDecoder(config, linear_method)
def forward( def forward(
self, self,
@@ -271,12 +296,15 @@ class OPTModel(nn.Module):
class OPTForCausalLM(nn.Module): class OPTForCausalLM(nn.Module):
def __init__(self, config): def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = OPTModel(config) self.linear_method = linear_method
# TODO(zhuohan): create a new weight after implementing pipeline self.model = OPTModel(config, linear_method)
# parallelism
self.lm_head_weight = self.model.decoder.embed_tokens.weight self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@@ -294,48 +322,31 @@ class OPTForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = [
"embed_tokens.weight", "fc1.weight", "fc1.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() stacked_params_mapping = [
state_dict = self.state_dict() # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name: if "lm_head.weight" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if name.startswith("decoder."): if weight_name not in name:
name = "model." + name
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")] param = params_dict[name.replace(weight_name, param_name)]
shard_size = param.shape[0] // 3 weight_loader = param.weight_loader
loaded_weight = loaded_weight[ weight_loader(param, loaded_weight, shard_id)
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break break
if is_attention_weight: else:
continue param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
param = state_dict[name] default_weight_loader)
load_tensor_parallel_weights(param, loaded_weight, name, weight_loader(param, loaded_weight)
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
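
Review note: the new `load_weights` relies on Python's `for ... else`: the `else` branch runs only when no entry of `stacked_params_mapping` matched, that is, when the inner loop finished without `break`. The sketch below is a minimal illustration of that control flow, assuming toy list-valued weights. `Param`, `stacked_loader`, and `default_loader` are hypothetical stand-ins, not the vLLM classes.

```python
from typing import Callable, Dict, List, Optional, Tuple


class Param:
    """Hypothetical parameter: records what was loaded into it."""

    def __init__(self, weight_loader: Optional[Callable] = None) -> None:
        self.shards: Dict[str, List[float]] = {}
        if weight_loader is not None:
            # Only parameters that need custom handling carry a loader.
            self.weight_loader = weight_loader


def stacked_loader(param: Param, loaded_weight: List[float],
                   shard_id: str) -> None:
    # Store one checkpoint tensor in its slot of the fused parameter.
    param.shards[shard_id] = list(loaded_weight)


def default_loader(param: Param, loaded_weight: List[float]) -> None:
    # Fallback for parameters without a custom loader.
    param.shards["full"] = list(loaded_weight)


STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def load_weights(params: Dict[str, Param],
                 checkpoint: List[Tuple[str, List[float]]]) -> None:
    for name, loaded_weight in checkpoint:
        for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
            if weight_name not in name:
                continue
            param = params[name.replace(weight_name, param_name)]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Reached only if the loop above never hit `break`.
            param = params[name]
            loader = getattr(param, "weight_loader", default_loader)
            loader(param, loaded_weight)


params = {
    "layers.0.self_attn.qkv_proj.weight": Param(stacked_loader),
    "layers.0.fc1.weight": Param(),
}
checkpoint = [
    ("layers.0.self_attn.q_proj.weight", [1.0]),
    ("layers.0.self_attn.k_proj.weight", [2.0]),
    ("layers.0.self_attn.v_proj.weight", [3.0]),
    ("layers.0.fc1.weight", [4.0]),
]
load_weights(params, checkpoint)
```

The three checkpoint entries for `q_proj`, `k_proj`, and `v_proj` all land in the single `qkv_proj` parameter, while `fc1.weight` falls through to the default loader.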


@@ -0,0 +1,316 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class PhiEmbedding(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
def forward(self, input_ids: torch.LongTensor):
return self.wte(input_ids)
class PhiAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
# pylint: disable=C0103
self.Wqkv = QKVParallelLinear(
self.hidden_size,
self.head_size,
self.total_num_heads,
linear_method=linear_method,
)
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_size,
self.total_num_heads,
bias=False,
linear_method=linear_method,
)
self.out_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
linear_method=linear_method,
)
scaling = self.head_size**-0.5
rotary_dim = config.rotary_dim
assert rotary_dim % 2 == 0
# pylint: disable=C0301
# Refer to:
# https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
rope_theta = 10000
max_position_embeddings = getattr(config, "n_positions", 2048)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_size,
scaling,
rotary_dim,
base=rope_theta,
max_position=max_position_embeddings)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output
class PhiMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
n_inner = getattr(config, "n_inner", None)
n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
self.fc1 = ColumnParallelLinear(
config.hidden_size,
n_inner,
linear_method=linear_method,
)
self.fc2 = RowParallelLinear(
n_inner,
config.hidden_size,
linear_method=linear_method,
)
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
n_inner)
def forward(self, hidden_states):
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class PhiLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.mixer = PhiAttention(config, linear_method)
self.mlp = PhiMLP(config, linear_method)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
attn_outputs = self.mixer(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_outputs + feed_forward_hidden_states + residual
return hidden_states
class PhiCausalLMHead(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
):
hidden_states = self.ln(hidden_states)
next_tokens = self.sampler(self.linear.weight, hidden_states,
input_metadata, self.linear.bias)
return next_tokens
class PhiModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.embd = PhiEmbedding(config)
self.h = nn.ModuleList([
PhiLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embd(input_ids)
for i in range(self.config.num_hidden_layers):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
return hidden_states
class PhiForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = PhiModel(config, linear_method)
self.lm_head = PhiCausalLMHead(config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
lm_logits = self.lm_head(hidden_states, input_metadata)
return lm_logits
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
# pylint: disable=E1136
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
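
`PhiLayer.forward` in the listing above uses a parallel residual layout: a single LayerNorm output feeds both the attention mixer and the MLP, and both results are added to the residual in one step. The sketch below contrasts that with the more common sequential layout, using plain Python lists instead of tensors. The names `parallel_block`, `sequential_block`, and the toy callables are hypothetical and exist only for illustration; they are not part of vLLM.

```python
from typing import Callable, List

Vector = List[float]
Fn = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise sum of two equally sized vectors."""
    return [x + y for x, y in zip(a, b)]


def parallel_block(x: Vector, norm: Fn, attn: Fn, mlp: Fn) -> Vector:
    """Parallel residual: attention and MLP both read the same normalized input."""
    normed = norm(x)
    return add(add(attn(normed), mlp(normed)), x)


def sequential_block(x: Vector, norm1: Fn, attn: Fn, norm2: Fn, mlp: Fn) -> Vector:
    """Sequential residual: the MLP reads the result of the attention step."""
    h = add(x, attn(norm1(x)))
    return add(h, mlp(norm2(h)))


# Toy stand-ins (hypothetical) for the real norm, attention and MLP modules.
def identity(v: Vector) -> Vector:
    return list(v)


def double(v: Vector) -> Vector:
    return [2.0 * e for e in v]


def negate(v: Vector) -> Vector:
    return [-e for e in v]


print(parallel_block([1.0, 2.0], identity, double, negate))
print(sequential_block([1.0, 2.0], identity, double, identity, negate))
```

Because both branches depend only on the normalized input, the parallel form needs one normalization per layer rather than two.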


@@ -15,24 +15,19 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (
-    convert_pyslice_to_tensor,
-    hf_model_weights_iterator,
-    load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights,
-)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.parallel_utils.layers import (
-    VocabParallelEmbedding,
-    ColumnParallelLinear,
-    RowParallelLinear,
-)
+    get_tensor_model_parallel_world_size)
+from vllm.model_executor.weight_utils import (default_weight_loader,
+                                              hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs.qwen import QWenConfig
@@ -46,20 +41,17 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(
-            hidden_size,
-            2 * intermediate_size,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
             bias=False,
-            gather_output=False,
-        )
-        self.c_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            input_is_parallel=True,
-        )
+            linear_method=linear_method)
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=False,
+                                        linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -74,12 +66,15 @@ class QWenMLP(nn.Module):
 class QWenAttention(nn.Module):
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 max_position_embeddings: int,
-                 rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -90,18 +85,18 @@ class QWenAttention(nn.Module):
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        # pylint: disable=invalid-name
-        self.c_attn = ColumnParallelLinear(
+        self.c_attn = QKVParallelLinear(
             hidden_size,
-            3 * hidden_size,
+            self.head_dim,
+            self.total_num_heads,
             bias=True,
-            gather_output=False,
+            linear_method=linear_method,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            input_is_parallel=True,
+            linear_method=linear_method,
         )
         self.scaling = self.head_dim**-0.5
         self.attn = PagedAttentionWithRoPE(
@@ -134,7 +129,11 @@ class QWenAttention(nn.Module):
 class QWenBlock(nn.Module):
-    def __init__(self, config: QWenConfig):
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -144,11 +143,14 @@ class QWenBlock(nn.Module):
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
                                   rope_theta=rope_theta,
-                                  rope_scaling=rope_scaling)
+                                  rope_scaling=rope_scaling,
+                                  linear_method=linear_method)
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
+        self.mlp = QWenMLP(config.hidden_size,
+                           config.intermediate_size // 2,
+                           linear_method=linear_method)
     def forward(
         self,
@@ -157,10 +159,14 @@ class QWenBlock(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln_1(hidden_states)
+        else:
+            hidden_states, residual = self.ln_1(hidden_states, residual)
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -168,30 +174,32 @@ class QWenBlock(nn.Module):
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
+        hidden_states, residual = self.ln_2(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 class QWenModel(nn.Module):
-    def __init__(self, config: QWenConfig):
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
-        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.wte = VocabParallelEmbedding(
-            vocab_size,
+            config.vocab_size,
             config.hidden_size,
         )
-        self.h = nn.ModuleList(
-            [QWenBlock(config) for _ in range(config.num_hidden_layers)])
+        self.h = nn.ModuleList([
+            QWenBlock(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
     def forward(
@@ -203,36 +211,37 @@ class QWenModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
+        residual = None
         for i in range(len(self.h)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.ln_f(hidden_states)
+        hidden_states, _ = self.ln_f(hidden_states, residual)
         return hidden_states
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: QWenConfig):
+    def __init__(
+        self,
+        config: QWenConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config)
-        vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(
-            config.hidden_size,
-            vocab_size,
-            bias=False,
-            gather_output=False,
-        )
+        self.linear_method = linear_method
+        self.transformer = QWenModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.sampler = Sampler(config.vocab_size)
     def forward(
@@ -249,75 +258,30 @@ class QWenLMHeadModel(nn.Module):
                                    input_metadata)
         return next_tokens
-    _column_parallel_weights = []
-    _row_parallel_weights = ["c_proj.weight"]
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
-        tp_world_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        state_dict = self.state_dict()
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w2", 0),
+            ("gate_up_proj", "w1", 1),
+        ]
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-            if "c_attn" in name:
-                total_num_heads = self.config.num_attention_heads
-                hidden_size = self.config.hidden_size
-                head_size = hidden_size // total_num_heads
-                num_heads = total_num_heads // tp_world_size
-                head_start = tp_rank * num_heads
-                head_end = (tp_rank + 1) * num_heads
-                if "weight" in name:
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size, hidden_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif "bias" in name:
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :]
-                    loaded_weight = loaded_weight.reshape(-1)
-            is_gate_up_weight = False
-            for stride_id, weight_name in enumerate(["w2", "w1"]):
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
-                                              (tp_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
-                assert param_slice.shape == loaded_weight.shape
-                param_slice.copy_(loaded_weight)
-                is_gate_up_weight = True
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
                 break
-            if is_gate_up_weight:
-                continue
-            param = state_dict[name]
-            if "wte" in name or "lm_head" in name:
-                load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tp_rank)
-                continue
-            load_tensor_parallel_weights(
-                param,
-                loaded_weight,
-                name,
-                self._column_parallel_weights,
-                self._row_parallel_weights,
-                tp_rank,
-            )
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
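
The QWen hunks replace the explicit `hidden_states = residual + hidden_states` lines with two-argument calls such as `self.ln_2(hidden_states, residual)`. The sketch below shows why the two layouts can compute the same values. It assumes, based only on how the diff uses the call, that the two-argument norm first adds its inputs and then returns the normalized sum together with the un-normalized sum. All function names are hypothetical, plain Python lists stand in for tensors, and the toy `attn`/`mlp` callables are placeholders rather than real layers.

```python
import math
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Fn = Callable[[Vector], Vector]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    """Plain RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def add_rms_norm(x: Vector, residual: Vector, weight: Vector,
                 eps: float = 1e-6) -> Tuple[Vector, Vector]:
    """Assumed two-argument behaviour: add first, then normalize the sum.

    Returns (normalized sum, un-normalized sum); the second value is the
    residual carried into the next call.
    """
    summed = [a + b for a, b in zip(x, residual)]
    return rms_norm(summed, weight, eps), summed


def old_block(h: Vector, w1: Vector, w2: Vector, attn: Fn, mlp: Fn) -> Vector:
    """Pre-change layout: each sub-layer adds its residual explicitly."""
    residual = h
    h = attn(rms_norm(h, w1))
    h = [r + v for r, v in zip(residual, h)]
    residual = h
    h = mlp(rms_norm(h, w2))
    return [r + v for r, v in zip(residual, h)]


def new_block(h: Vector, residual: Optional[Vector], w1: Vector, w2: Vector,
              attn: Fn, mlp: Fn) -> Tuple[Vector, Vector]:
    """Post-change layout: the residual add is deferred to the next norm."""
    if residual is None:
        residual = h
        h = rms_norm(h, w1)
    else:
        h, residual = add_rms_norm(h, residual, w1)
    h = attn(h)
    h, residual = add_rms_norm(h, residual, w2)
    h = mlp(h)
    return h, residual


# Toy stand-ins (hypothetical) for the attention and MLP sub-layers.
def toy_attn(v: Vector) -> Vector:
    return [0.5 * e + 1.0 for e in v]


def toy_mlp(v: Vector) -> Vector:
    return [e * e - 0.25 for e in v]
```

In this sketch the last block's pending add is completed by summing the returned pair; in the diff that role is played by the final `self.ln_f(hidden_states, residual)` call.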

View File

@@ -0,0 +1,326 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Yi model (https://01.ai) compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
from vllm.transformers_utils.configs.yi import YiConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class YiMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class YiAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is at least the TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
rope_scaling=rope_scaling)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class YiDecoderLayer(nn.Module):
def __init__(
self,
config: YiConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = YiAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
self.mlp = YiMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.ln1(hidden_states)
else:
hidden_states, residual = self.ln1(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
hidden_states, residual = self.ln2(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class YiModel(nn.Module):
def __init__(
self,
config: YiConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
YiDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class YiForCausalLM(nn.Module):
def __init__(
self,
config: YiConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = YiModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
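
The `load_weights` method above routes each checkpoint tensor either into a shard of a stacked parameter (`qkv_proj`, `gate_up_proj`) or into a parameter of the same name, using a `for`/`else` loop. The sketch below isolates just the name resolution so the control flow is easier to follow. `resolve` is a hypothetical helper, and the checkpoint names in the usage lines are illustrative assumptions rather than names read from a real Yi checkpoint.

```python
from typing import List, Optional, Tuple, Union

ShardId = Union[str, int]

# Same tuples as stacked_params_mapping in load_weights:
# (param_name, shard_name, shard_id)
STACKED_PARAMS_MAPPING: List[Tuple[str, str, ShardId]] = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def resolve(name: str) -> Tuple[str, Optional[ShardId]]:
    """Map a checkpoint tensor name to (model parameter name, shard id)."""
    result: Tuple[str, Optional[ShardId]]
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name not in name:
            continue
        # Stacked case: the shard id tells the loader which slice to fill.
        result = (name.replace(weight_name, param_name), shard_id)
        break
    else:
        # for/else: reached only when the loop ended without `break`,
        # i.e. the tensor is not part of any stacked parameter.
        result = (name, None)
    return result


print(resolve("model.layers.0.self_attn.k_proj.weight"))
print(resolve("model.layers.0.mlp.down_proj.weight"))
```

Keeping the mapping as data means a model with different checkpoint names only needs a different list, as the QWen version does with `w2` and `w1`.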


@@ -1,303 +0,0 @@
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
from typing import Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from vllm.model_executor.parallel_utils.utils import (
divide,
VocabUtility,
split_tensor_along_last_dim,
)
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# TODO: Handle vocab padding here.
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = (
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
Keyword Arguments:
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configuration.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.tp_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Parameters.
# NOTE: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.create_weights(params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
else:
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(
torch.empty(self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=dtype))
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
return F.linear(x, self.weight, bias)
def forward(self, input_):
"""Forward of ColumnParallelLinear
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = self.bias if not self.skip_bias_add else None
input_parallel = input_
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
Keyword Arguments:
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configuration.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError('When results are not reduced, adding bias to the '
'results can lead to incorrect results')
if bias:
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(
torch.empty(self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight)
def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# Split the input across tensor-parallel ranks unless it is already split.
if self.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias
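
The row-parallel scheme described in the RowParallelLinear docstring can be checked with a small pure-Python sketch. This is not the vLLM implementation: plain lists stand in for tensors, an element-wise sum over the partial outputs stands in for the all-reduce, and the helper names (`matmul`, `row_parallel_linear`) are made up for illustration.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, a: Matrix) -> Matrix:
    """Plain matrix product X @ A for small list-of-lists matrices."""
    return [[sum(x_row[k] * a[k][j] for k in range(len(a)))
             for j in range(len(a[0]))] for x_row in x]


def row_parallel_linear(x: Matrix, a: Matrix, tp_size: int) -> Matrix:
    """Compute X @ A by splitting A along rows and X along columns."""
    input_size = len(a)
    assert input_size % tp_size == 0, "input_size must divide evenly"
    chunk = input_size // tp_size
    partial_outputs = []
    for rank in range(tp_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        x_shard = [row[lo:hi] for row in x]  # X_i: columns lo..hi of X
        a_shard = a[lo:hi]  # A_i: rows lo..hi of A
        partial_outputs.append(matmul(x_shard, a_shard))
    # Stand-in for the all-reduce: element-wise sum of the partial outputs.
    rows, cols = len(x), len(a[0])
    return [[sum(p[i][j] for p in partial_outputs) for j in range(cols)]
            for i in range(rows)]


x = [[1.0, 2.0, 3.0, 4.0]]
a = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 2.0]]
full = matmul(x, a)
sharded = row_parallel_linear(x, a, tp_size=2)
```

Each rank holds only a partial sum, so a bias added before the reduction would be counted once per rank. That is presumably why the layer adds the bias after the all-reduce and rejects `bias=True` with `reduce_results=False` unless `skip_bias_add` is set.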

@@ -2,7 +2,7 @@
# Adapted from # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Sequence from typing import Sequence
import torch import torch
@@ -24,7 +24,7 @@ def split_tensor_along_last_dim(
tensor: torch.Tensor, tensor: torch.Tensor,
num_partitions: int, num_partitions: int,
contiguous_split_chunks: bool = False, contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]: ) -> Sequence[torch.Tensor]:
""" Split a tensor along its last dimension. """ Split a tensor along its last dimension.
Arguments: Arguments:
@@ -46,25 +46,3 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list) return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list return tensor_list
class VocabUtility:
""" Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition. Note that the indices are in [first, last).
"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank: int) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank)
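
As a worked example of the half-open range computed by the removed VocabUtility helpers, here is a standalone sketch. The function name `vocab_range` is hypothetical, and integer division guarded by an assertion stands in for the `divide` helper, assuming that helper requires exact divisibility.

```python
from typing import Tuple


def vocab_range(global_vocab_size: int, rank: int,
                world_size: int) -> Tuple[int, int]:
    """Return the half-open index range [first, last) owned by `rank`."""
    assert global_vocab_size % world_size == 0, (
        "vocabulary size must be divisible by world_size")
    per_partition = global_vocab_size // world_size
    first = rank * per_partition
    return first, first + per_partition


# Four ranks sharing a 32000-token vocabulary.
ranges = [vocab_range(32000, rank, 4) for rank in range(4)]
```

Because the ranges are half-open, the `last` index of one rank equals the `first` index of the next, and the partitions cover the vocabulary without overlap.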

@@ -1,20 +0,0 @@
from typing import Type
from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig
_QUANTIZATION_REGISTRY = {
"awq": AWQConfig,
}
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quant_class",
]

@@ -1,72 +0,0 @@
from typing import Any, Dict, List
import torch
from vllm.model_executor.quantization_utils.base import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
@classmethod
def get_name(cls) -> str:
return "awq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
@classmethod
def get_config_filenames(cls) -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros"]
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
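
The expression `pack_factor = 32 // weight_bits` gives the number of quantized values that fit in one 32-bit word: 8 for 4-bit weights. The sketch below illustrates that arithmetic only. It packs values sequentially, lowest bits first, which is an assumption made for illustration; the element order used by the actual AWQ kernels is not shown in this diff and may differ. `pack_int4` and `unpack_int4` are hypothetical names.

```python
from typing import List

WEIGHT_BITS = 4
PACK_FACTOR = 32 // WEIGHT_BITS  # 8 four-bit values per 32-bit word


def pack_int4(values: List[int]) -> List[int]:
    """Pack 4-bit values into 32-bit words, lowest bits first."""
    assert len(values) % PACK_FACTOR == 0
    packed = []
    for start in range(0, len(values), PACK_FACTOR):
        word = 0
        for i, v in enumerate(values[start:start + PACK_FACTOR]):
            assert 0 <= v < (1 << WEIGHT_BITS)
            word |= v << (i * WEIGHT_BITS)
        packed.append(word)
    return packed


def unpack_int4(packed: List[int]) -> List[int]:
    """Inverse of pack_int4."""
    mask = (1 << WEIGHT_BITS) - 1
    return [(word >> (i * WEIGHT_BITS)) & mask
            for word in packed for i in range(PACK_FACTOR)]


values = list(range(16))
packed = pack_int4(values)
restored = unpack_int4(packed)
```

Sixteen 4-bit values become two 32-bit words, so a packed tensor has one eighth as many elements along the packed dimension as the logical weight.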

@@ -1,5 +1,6 @@
"""Utils for model executor.""" """Utils for model executor."""
import random import random
from typing import Any, Dict, Optional
import numpy as np import numpy as np
import torch import torch
@@ -11,3 +12,24 @@ def set_random_seed(seed: int) -> None:
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def set_weight_attrs(
weight: torch.Tensor,
weight_attrs: Optional[Dict[str, Any]],
):
"""Set attributes on a weight tensor.
This function sets attributes on a weight tensor. It never overwrites an
existing attribute; attempting to do so fails an assertion.
Args:
weight: The weight tensor.
weight_attrs: A dictionary of attributes to set on the weight tensor.
"""
if weight_attrs is None:
return
for key, value in weight_attrs.items():
assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value)
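
A minimal sketch of how `set_weight_attrs` behaves, written without torch so it runs anywhere. `FakeWeight` is a hypothetical stand-in for a tensor, and the attribute names `input_dim` and `output_dim` are chosen only for illustration.

```python
from typing import Any, Dict, Optional


class FakeWeight:
    """Stand-in for a torch.Tensor; any object accepting attributes works."""


def set_weight_attrs(weight: Any,
                     weight_attrs: Optional[Dict[str, Any]]) -> None:
    """Same logic as the function in the diff, minus the tensor type hint."""
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        assert not hasattr(
            weight, key), (f"Overwriting existing tensor attribute: {key}")
        setattr(weight, key, value)


weight = FakeWeight()
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, None)  # A None dictionary is a no-op.

try:
    set_weight_attrs(weight, {"input_dim": 0})
    overwrote = True
except AssertionError:
    overwrote = False
```

A second attempt to set the same key fails the assertion rather than being skipped silently. Since the guard is an `assert`, it disappears when Python runs with `-O`.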

@@ -7,14 +7,15 @@ from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open
import numpy as np import numpy as np
from safetensors.torch import load_file, save_file, safe_open
import torch import torch
from transformers import PretrainedConfig
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.quantization_utils import get_quant_class from vllm.model_executor.layers.quantization import (get_quantization_config,
from vllm.model_executor.quantization_utils.base import QuantizationConfig QuantizationConfig)
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -84,8 +85,15 @@ def convert_bin_to_safetensor_file(
def get_quant_config( def get_quant_config(
quantization: str, quantization: str,
model_name_or_path: str, model_name_or_path: str,
hf_config: PretrainedConfig,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
) -> QuantizationConfig: ) -> QuantizationConfig:
quant_cls = get_quantization_config(quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(hf_config, "quantization_config", None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)
if not is_local: if not is_local:
# Download the config files. # Download the config files.
@@ -98,7 +106,6 @@ def get_quant_config(
hf_folder = model_name_or_path hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json")) config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_cls = get_quant_class(quantization)
quant_config_files = [ quant_config_files = [
f for f in config_files if any( f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames()) f.endswith(x) for x in quant_cls.get_config_filenames())
@@ -236,7 +243,7 @@ def hf_model_weights_iterator(
for st_file in hf_weights_files: for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f: with safe_open(st_file, framework="pt") as f:
for name in f.keys(): for name in f.keys():
param = f.get_slice(name) param = f.get_tensor(name)
yield name, param yield name, param
else: else:
for bin_file in hf_weights_files: for bin_file in hf_weights_files:
@@ -262,46 +269,10 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
return x return x
def load_padded_tensor_parallel_vocab( def default_weight_loader(param: torch.Tensor,
param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` """Default weight loader."""
tensor_model_parallel_rank: int, assert param.size() == loaded_weight.size()
) -> None:
shard_size = param.shape[0]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[start_idx:end_idx]
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
param[:loaded_weight.shape[0]].copy_(loaded_weight)
def load_tensor_parallel_weights(
param: torch.Tensor,
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
param_name: str,
column_parallel_weight_names: List[str],
row_parallel_weight_names: List[str],
tensor_model_parallel_rank: int,
) -> None:
for p in column_parallel_weight_names:
if p in param_name:
shard_size = param.shape[0]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[start_idx:end_idx]
break
for p in row_parallel_weight_names:
if p in param_name:
shard_size = param.shape[1]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[:, start_idx:end_idx]
break
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: "
f"{param.shape} != {loaded_weight.shape}")
param.data.copy_(loaded_weight) param.data.copy_(loaded_weight)

@@ -53,6 +53,7 @@ class RequestOutput:
request_id: The unique ID of the request. request_id: The unique ID of the request.
prompt: The prompt string of the request. prompt: The prompt string of the request.
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
prompt_logprobs: The log probabilities to return per prompt token.
outputs: The output sequences of the request. outputs: The output sequences of the request.
finished: Whether the whole request is finished. finished: Whether the whole request is finished.
""" """

vllm/py.typed Normal file
@@ -0,0 +1,2 @@
# Marker file for PEP 561.
# The vllm package uses inline types.

@@ -1,7 +1,8 @@
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
from enum import IntEnum from enum import IntEnum
from functools import cached_property from functools import cached_property
from typing import List, Optional, Union from typing import Callable, List, Optional, Union
import torch
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
@@ -12,6 +13,12 @@ class SamplingType(IntEnum):
BEAM = 2 BEAM = 2
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""
class SamplingParams: class SamplingParams:
"""Sampling parameters for text generation. """Sampling parameters for text generation.
@@ -34,6 +41,10 @@ class SamplingParams:
frequency in the generated text so far. Values > 0 encourage the frequency in the generated text so far. Values > 0 encourage the
model to use new tokens, while values < 0 encourage the model to model to use new tokens, while values < 0 encourage the model to
repeat tokens. repeat tokens.
repetition_penalty: Float that penalizes new tokens based on whether
they appear in the generated text so far. Values > 1 encourage the
model to use new tokens, while values < 1 encourage the model to
repeat tokens.
temperature: Float that controls the randomness of the sampling. Lower temperature: Float that controls the randomness of the sampling. Lower
values make the model more deterministic, while higher values make values make the model more deterministic, while higher values make
the model more random. Zero means greedy sampling. the model more random. Zero means greedy sampling.
@@ -41,6 +52,9 @@ class SamplingParams:
to consider. Must be in (0, 1]. Set to 1 to consider all tokens. to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens. to -1 to consider all tokens.
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
use_beam_search: Whether to use beam search instead of sampling. use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length. length_penalty: Float that penalizes sequences based on their length.
Used in beam search. Used in beam search.
@@ -67,6 +81,10 @@ class SamplingParams:
`logprobs+1` elements in the response. `logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token. prompt_logprobs: Number of log probabilities to return per prompt token.
skip_special_tokens: Whether to skip special tokens in the output. skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
""" """
def __init__( def __init__(
@@ -75,9 +93,11 @@ class SamplingParams:
best_of: Optional[int] = None, best_of: Optional[int] = None,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
repetition_penalty: float = 1.0,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
min_p: float = 0.0,
use_beam_search: bool = False, use_beam_search: bool = False,
length_penalty: float = 1.0, length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False, early_stopping: Union[bool, str] = False,
@@ -88,14 +108,18 @@ class SamplingParams:
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
) -> None: ) -> None:
self.n = n self.n = n
self.best_of = best_of if best_of is not None else n self.best_of = best_of if best_of is not None else n
self.presence_penalty = presence_penalty self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty self.frequency_penalty = frequency_penalty
self.repetition_penalty = repetition_penalty
self.temperature = temperature self.temperature = temperature
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
self.min_p = min_p
self.use_beam_search = use_beam_search self.use_beam_search = use_beam_search
self.length_penalty = length_penalty self.length_penalty = length_penalty
self.early_stopping = early_stopping self.early_stopping = early_stopping
@@ -114,7 +138,8 @@ class SamplingParams:
self.logprobs = logprobs self.logprobs = logprobs
self.prompt_logprobs = prompt_logprobs self.prompt_logprobs = prompt_logprobs
self.skip_special_tokens = skip_special_tokens self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
self._verify_beam_search() self._verify_beam_search()
@@ -136,6 +161,9 @@ class SamplingParams:
if not -2.0 <= self.frequency_penalty <= 2.0: if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError("frequency_penalty must be in [-2, 2], got " raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.") f"{self.frequency_penalty}.")
if not 0.0 < self.repetition_penalty <= 2.0:
raise ValueError("repetition_penalty must be in (0, 2], got "
f"{self.repetition_penalty}.")
if self.temperature < 0.0: if self.temperature < 0.0:
raise ValueError( raise ValueError(
f"temperature must be non-negative, got {self.temperature}.") f"temperature must be non-negative, got {self.temperature}.")
@@ -144,6 +172,9 @@ class SamplingParams:
if self.top_k < -1 or self.top_k == 0: if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, " raise ValueError(f"top_k must be -1 (disable), or at least 1, "
f"got {self.top_k}.") f"got {self.top_k}.")
if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0, 1], got "
f"{self.min_p}.")
if self.max_tokens < 1: if self.max_tokens < 1:
raise ValueError( raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.") f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -201,9 +232,11 @@ class SamplingParams:
f"best_of={self.best_of}, " f"best_of={self.best_of}, "
f"presence_penalty={self.presence_penalty}, " f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, " f"frequency_penalty={self.frequency_penalty}, "
f"repetition_penalty={self.repetition_penalty}, "
f"temperature={self.temperature}, " f"temperature={self.temperature}, "
f"top_p={self.top_p}, " f"top_p={self.top_p}, "
f"top_k={self.top_k}, " f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"use_beam_search={self.use_beam_search}, " f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, " f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, " f"early_stopping={self.early_stopping}, "
@@ -212,4 +245,6 @@ class SamplingParams:
f"max_tokens={self.max_tokens}, " f"max_tokens={self.max_tokens}, "
f"logprobs={self.logprobs}, " f"logprobs={self.logprobs}, "
f"prompt_logprobs={self.prompt_logprobs}, " f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens})") f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens})")
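
The `min_p` docstring defines the cutoff relative to the most likely token. The sketch below shows one way to read that rule on a plain list of probabilities. It is an illustration under stated assumptions, not the sampler code from this change: `apply_min_p` is a hypothetical helper, and renormalizing the surviving probabilities is an assumption about what happens after filtering.

```python
from typing import List, Sequence


def apply_min_p(probs: Sequence[float], min_p: float) -> List[float]:
    """Drop tokens whose probability is below min_p * max(probs)."""
    if not 0.0 <= min_p <= 1.0:
        raise ValueError(f"min_p must be in [0, 1], got {min_p}.")
    threshold = min_p * max(probs)
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)  # Never zero: the most likely token always survives.
    return [p / total for p in kept]


probs = [0.5, 0.3, 0.15, 0.05]
# Threshold is 0.2 * 0.5 = 0.1, so only the last token is removed.
filtered = apply_min_p(probs, min_p=0.2)
# With min_p = 0 the threshold is 0 and every token is kept.
unfiltered = apply_min_p(probs, min_p=0.0)
```

Unlike a fixed `top_k`, the number of surviving tokens adapts to how peaked the distribution is: a confident prediction prunes more of the tail than a flat one.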

@@ -401,6 +401,12 @@ class SequenceGroupOutputs:
return (f"SequenceGroupOutputs(samples={self.samples}, " return (f"SequenceGroupOutputs(samples={self.samples}, "
f"prompt_logprobs={self.prompt_logprobs})") f"prompt_logprobs={self.prompt_logprobs})")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceGroupOutputs):
return NotImplemented
return (self.samples == other.samples
and self.prompt_logprobs == other.prompt_logprobs)
# For each sequence group, we generate a list of SequenceOutputs object, # For each sequence group, we generate a list of SequenceOutputs object,
# each of which contains one possible candidate for the next token. # each of which contains one possible candidate for the next token.
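
For comparison methods such as `__eq__`, Python's data model provides the `NotImplemented` singleton to signal an unsupported operand type. The sketch below shows that convention on a hypothetical `Samples` class that mirrors the two compared fields; it is not the vLLM class.

```python
class Samples:
    """Hypothetical value class used only for this illustration."""

    def __init__(self, samples, prompt_logprobs=None):
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Samples):
            # Let Python try the reflected comparison and then fall back
            # to identity, instead of raising an exception.
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


same = Samples([1, 2]) == Samples([1, 2])
different_type = Samples([1, 2]) == "not a Samples"
```

With this convention, comparing against an unrelated type evaluates to `False` rather than raising, which keeps checks such as `x in some_list` safe for mixed-type containers.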

@@ -5,12 +5,14 @@ from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
_CONFIG_REGISTRY = { _CONFIG_REGISTRY = {
"mpt": MPTConfig,
"baichuan": BaiChuanConfig,
"aquila": AquilaConfig, "aquila": AquilaConfig,
"baichuan": BaiChuanConfig,
"chatglm": ChatGLMConfig,
"mpt": MPTConfig,
"qwen": QWenConfig, "qwen": QWenConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
"yi": YiConfig,
} }

@@ -1,16 +1,20 @@
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.aquila import AquilaConfig from vllm.transformers_utils.configs.aquila import AquilaConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.qwen import QWenConfig from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library. # `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.yi import YiConfig
__all__ = [ __all__ = [
"MPTConfig",
"BaiChuanConfig",
"AquilaConfig", "AquilaConfig",
"BaiChuanConfig",
"ChatGLMConfig",
"MPTConfig",
"QWenConfig", "QWenConfig",
"RWConfig", "RWConfig",
"YiConfig",
] ]

@@ -0,0 +1,68 @@
# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"
attribute_map = {
"num_hidden_layers": "num_layers",
"n_head_kv": "multi_query_group_num",
}
def __init__(self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
interleaved_qkv=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = (
apply_residual_connection_post_layernorm)
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.quantization_bit = quantization_bit
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
self.interleaved_qkv = interleaved_qkv
super().__init__(**kwargs)

@@ -1,52 +1,124 @@
# Adapted from # coding=utf-8
# Copied from
# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py # https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
"""A HuggingFace-style model configuration."""
import warnings
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from transformers import PretrainedConfig from transformers import PretrainedConfig
_ATTN_CONFIG_DEFAULTS = { attn_config_defaults: Dict = {
"attn_type": "multihead_attention", 'attn_type': 'multihead_attention',
"attn_pdrop": 0.0, 'attn_pdrop': 0.0,
"attn_impl": "triton", 'attn_impl': 'triton',
"qk_ln": False, 'qk_ln': False,
"clip_qkv": None, 'clip_qkv': None,
"softmax_scale": None, 'softmax_scale': None,
"prefix_lm": False, 'prefix_lm': False,
"attn_uses_sequence_id": False, 'attn_uses_sequence_id': False,
"alibi": False, 'alibi': False,
"alibi_bias_max": 8, 'alibi_bias_max': 8
}
ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
init_config_defaults: Dict = {
'name': 'kaiming_normal_',
'fan_mode': 'fan_in',
'init_nonlinearity': 'relu',
'init_div_is_residual': True,
'emb_init_std': None,
'emb_init_uniform_lim': None,
'init_std': None,
'init_gain': 0.0
} }
class MPTConfig(PretrainedConfig): class MPTConfig(PretrainedConfig):
model_type = "mpt" model_type = 'mpt'
attribute_map = { attribute_map = {
"hidden_size": "d_model", 'num_attention_heads': 'n_heads',
"num_attention_heads": "n_heads", 'hidden_size': 'd_model',
"num_hidden_layers": "n_layers", 'num_hidden_layers': 'n_layers',
} }
def __init__( # pylint: disable=dangerous-default-value
self, def __init__(self,
d_model: int = 2048, d_model: int = 2048,
n_heads: int = 16, n_heads: int = 16,
n_layers: int = 24, n_layers: int = 24,
expansion_ratio: int = 4, expansion_ratio: int = 4,
max_seq_len: int = 2048, max_seq_len: int = 2048,
vocab_size: int = 50368, vocab_size: int = 50368,
resid_pdrop: float = 0.0, resid_pdrop: float = 0.0,
emb_pdrop: float = 0.0, emb_pdrop: float = 0.0,
learned_pos_emb: bool = True, learned_pos_emb: bool = True,
attn_config: Optional[Dict[str, Any]] = None, attn_config: Dict = attn_config_defaults,
init_device: str = "cpu", ffn_config: Dict = ffn_config_defaults,
logit_scale: Optional[Union[float, str]] = None, init_device: str = 'cpu',
no_bias: bool = False, logit_scale: Optional[Union[float, str]] = None,
verbose: int = 0, no_bias: bool = False,
embedding_fraction: float = 1.0, embedding_fraction: float = 1.0,
norm_type: str = "low_precision_layernorm", norm_type: str = 'low_precision_layernorm',
use_cache: bool = False, use_cache: bool = False,
**kwargs, init_config: Dict = init_config_defaults,
) -> None: fc_type: str = 'torch',
verbose: Optional[int] = None,
**kwargs: Any):
# pylint: disable=line-too-long
"""The MPT configuration class.
Args:
d_model (int): The size of the embedding dimension of the model.
n_heads (int): The number of attention heads.
n_layers (int): The number of layers in the model.
expansion_ratio (int): The ratio of the up/down scale in the ffn.
max_seq_len (int): The maximum sequence length of the model.
vocab_size (int): The size of the vocabulary.
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
emb_pdrop (float): The dropout probability for the embedding layer.
learned_pos_emb (bool): Whether to use learned positional embeddings
attn_config (Dict): A dictionary used to configure the model's attention module:
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
attn_pdrop (float): The dropout probability for the attention layers.
attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
use the default scale of ``1/sqrt(d_keys)``.
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
which sub-sequence each token belongs to.
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
alibi (bool): Whether to use the alibi bias instead of position embeddings.
alibi_bias_max (int): The maximum value of the alibi bias.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
verbose (int): The verbosity level. 0 is silent.
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
norm_type (str): choose type of norm to use
use_cache (bool): Whether or not the model should return the last key/values attentions
init_config (Dict): A dictionary used to configure the model initialization:
init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
init_std (float): The standard deviation of the normal distribution used to initialize the model,
if using the baseline_ parameter initialization scheme.
init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
---
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
"""
self.d_model = d_model self.d_model = d_model
self.n_heads = n_heads self.n_heads = n_heads
self.n_layers = n_layers self.n_layers = n_layers
@@ -56,19 +128,105 @@ class MPTConfig(PretrainedConfig):
self.resid_pdrop = resid_pdrop self.resid_pdrop = resid_pdrop
self.emb_pdrop = emb_pdrop self.emb_pdrop = emb_pdrop
self.learned_pos_emb = learned_pos_emb self.learned_pos_emb = learned_pos_emb
if attn_config is None: self.attn_config = attn_config
self.attn_config = _ATTN_CONFIG_DEFAULTS self.ffn_config = ffn_config
else:
self.attn_config = attn_config
self.init_device = init_device self.init_device = init_device
self.logit_scale = logit_scale self.logit_scale = logit_scale
self.no_bias = no_bias self.no_bias = no_bias
self.verbose = verbose
self.embedding_fraction = embedding_fraction self.embedding_fraction = embedding_fraction
self.norm_type = norm_type self.norm_type = norm_type
self.use_cache = use_cache self.use_cache = use_cache
if "name" in kwargs: self.init_config = init_config
del kwargs["name"] self.fc_type = fc_type
if "loss_fn" in kwargs: if verbose is not None:
del kwargs["loss_fn"] warnings.warn(
DeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
))
if 'name' in kwargs:
del kwargs['name']
if 'loss_fn' in kwargs:
del kwargs['loss_fn']
if self.attn_config.get('alibi', False):
self.learned_pos_emb = False
warnings.warn(
f'alibi is turned on, setting `learned_pos_emb` to {self.learned_pos_emb}`'
)
super().__init__(**kwargs) super().__init__(**kwargs)
self._validate_config()
def _set_config_defaults(
self, config: Dict[str, Any],
config_defaults: Dict[str, Any]) -> Dict[str, Any]:
for (k, v) in config_defaults.items():
if k not in config:
config[k] = v
return config
def _validate_config(self) -> None:
self.attn_config = self._set_config_defaults(self.attn_config,
attn_config_defaults)
self.ffn_config = self._set_config_defaults(self.ffn_config,
ffn_config_defaults)
self.init_config = self._set_config_defaults(self.init_config,
init_config_defaults)
if self.d_model % self.n_heads != 0:
raise ValueError('d_model must be divisible by n_heads')
if any((
prob < 0 or prob > 1 for prob in
[self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop]
)):
raise ValueError(
"self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" # pylint: disable=line-too-long
)
if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
raise ValueError(
f"Unknown attn_impl={self.attn_config['attn_impl']}")
if self.attn_config['prefix_lm'] and self.attn_config[
'attn_impl'] not in ['torch', 'triton']:
raise NotImplementedError(
'prefix_lm only implemented with torch and triton attention.')
if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [
'torch', 'triton'
]:
raise NotImplementedError(
'alibi only implemented with torch and triton attention.')
if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
'attn_impl'] not in ['torch', 'triton']:
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.' # pylint: disable=line-too-long
)
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' # pylint: disable=line-too-long
)
if isinstance(self.logit_scale,
str) and self.logit_scale != 'inv_sqrt_d_model':
raise ValueError(
f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." # pylint: disable=line-too-long
)
if self.init_config.get('name', None) is None:
raise ValueError(
f"self.init_config={self.init_config!r} 'name' needs to be set."
)
if not self.learned_pos_emb and (not self.attn_config['alibi']):
warnings.warn(
'Positional information not being provided to the model.')
if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp':
try:
# pylint: disable=import-outside-toplevel
import transformer_engine.pytorch as te
del te
except Exception as exc:
raise ImportError(
# pylint: disable=line-too-long
'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. '
+
'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n'
+ 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
) from exc
if self.ffn_config['ffn_type'] == 'mptmlp':
self.ffn_config['fc_type'] = self.fc_type
elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
self.ffn_config['bias'] = not self.no_bias
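Review note: `_set_config_defaults` fills in only the keys the caller left out, so user-supplied values always win over the defaults. A minimal standalone sketch of that merge follows. The `ATTN_DEFAULTS` values here are made up for illustration and are not the defaults defined in the real module.

```python
from typing import Any, Dict

# Hypothetical defaults, for illustration only.
ATTN_DEFAULTS: Dict[str, Any] = {
    "attn_impl": "triton",
    "alibi": False,
    "attn_pdrop": 0.0,
}


def set_config_defaults(config: Dict[str, Any],
                        defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Copy every default whose key is missing from `config`.

    The dict is updated in place and also returned, mirroring the
    method in the diff above.
    """
    for key, value in defaults.items():
        if key not in config:
            config[key] = value
    return config


user_attn = {"alibi": True}
merged = set_config_defaults(user_attn, ATTN_DEFAULTS)
print(merged)
```

Because the merge mutates its argument, `merged` and `user_attn` are the same object after the call. Callers that need the original dict untouched would have to copy it first.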


@@ -0,0 +1,64 @@
""" Yi model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
Yi_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class YiConfig(PretrainedConfig):
r"""
Reference:
https://huggingface.co/01-ai/Yi-6B/blob/main/configuration_yi.py
"""
model_type = "Yi"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=64000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
output_attentions=False,
rope_theta=5000000.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.output_attentions = output_attentions
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
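Review note: the `num_key_value_heads` fallback in `YiConfig` keeps older configs working by treating a missing value as plain multi-head attention. The sketch below shows that fallback and the resulting number of query heads that share each key/value head. The helper names and the divisibility check are assumptions for illustration and are not part of the config class.

```python
from typing import Optional


def resolve_kv_heads(num_attention_heads: int,
                     num_key_value_heads: Optional[int]) -> int:
    """Fall back to one KV head per attention head when unset."""
    if num_key_value_heads is None:
        return num_attention_heads
    return num_key_value_heads


def queries_per_kv_head(num_attention_heads: int,
                        num_key_value_heads: Optional[int]) -> int:
    """Number of query heads that share one key/value head."""
    kv_heads = resolve_kv_heads(num_attention_heads, num_key_value_heads)
    # Assumption: grouped-query attention needs an even split.
    if num_attention_heads % kv_heads != 0:
        raise ValueError("num_attention_heads must be divisible by "
                         "num_key_value_heads")
    return num_attention_heads // kv_heads


# The defaults in the diff: 32 attention heads, 4 key/value heads.
print(queries_per_kv_head(32, 4))
# Backward-compatible case: no KV head count given.
print(queries_per_kv_head(32, None))
```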


@@ -73,6 +73,7 @@ def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str], output_tokens: List[str],
skip_special_tokens: bool, skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str: ) -> str:
# Adapted from # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
@@ -96,7 +97,10 @@ def _convert_tokens_to_string_with_added_encoders(
if current_sub_text: if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text) sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text) sub_texts.append(sub_text)
return " ".join(sub_texts) if spaces_between_special_tokens:
return " ".join(sub_texts)
else:
return "".join(sub_texts)
# Based on # Based on
@@ -109,6 +113,7 @@ def detokenize_incrementally(
prefix_offset: int = 0, prefix_offset: int = 0,
read_offset: int = 0, read_offset: int = 0,
skip_special_tokens: bool = False, skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]: ) -> Tuple[List[str], str, int, int]:
new_token_id = all_input_ids[-1] new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence # This is the first iteration for this sequence
@@ -120,7 +125,11 @@ def detokenize_incrementally(
# tokenizers (bigger = more conservative). # tokenizers (bigger = more conservative).
# Subtract 1 extra to account for the generated token. # Subtract 1 extra to account for the generated token.
prefix_offset = max(len(output_tokens) - 6, 0) prefix_offset = max(len(output_tokens) - 6, 0)
read_offset = max(len(output_tokens) - 1, 0) # If the first new token is a special token, we can't skip 1 extra token
if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
read_offset = max(len(output_tokens), 0)
else:
read_offset = max(len(output_tokens) - 1, 0)
else: else:
# Put new_token_id in a list so skip_special_tokens is respected # Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens( new_tokens = tokenizer.convert_ids_to_tokens(
@@ -139,11 +148,15 @@ def detokenize_incrementally(
prefix_text = _convert_tokens_to_string_with_added_encoders( prefix_text = _convert_tokens_to_string_with_added_encoders(
tokenizer, tokenizer,
output_tokens[prefix_offset:read_offset], output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens) skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = _convert_tokens_to_string_with_added_encoders( new_text = _convert_tokens_to_string_with_added_encoders(
tokenizer, tokenizer,
output_tokens[prefix_offset:], output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens) skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
if len(new_text) > len(prefix_text) and not new_text.endswith("<EFBFBD>"): if len(new_text) > len(prefix_text) and not new_text.endswith("<EFBFBD>"):
# utf-8 char at the end means it's a potential unfinished byte sequence # utf-8 char at the end means it's a potential unfinished byte sequence
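Review note: the new `spaces_between_special_tokens` flag only changes the separator used when the decoded pieces are joined. A minimal sketch of that final step is below. The function name and the special-token strings are hypothetical and stand in for whatever the tokenizer produces.

```python
from typing import List


def join_sub_texts(sub_texts: List[str],
                   spaces_between_special_tokens: bool) -> str:
    """Join decoded pieces with a space, or with nothing at all."""
    separator = " " if spaces_between_special_tokens else ""
    return separator.join(sub_texts)


# Hypothetical pieces: special tokens interleaved with decoded text.
pieces = ["<|user|>", "Hello world", "<|end|>"]

with_spaces = join_sub_texts(pieces, spaces_between_special_tokens=True)
without_spaces = join_sub_texts(pieces, spaces_between_special_tokens=False)
print(with_spaces)
print(without_spaces)
```

The default of `True` in `detokenize_incrementally` preserves the previous behaviour, so only callers that pass `False` see the tighter output.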


@@ -10,10 +10,10 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
from vllm.model_executor import get_model, InputMetadata, set_random_seed from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel) initialize_model_parallel)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes from vllm.utils import get_gpu_memory
class Worker: class Worker:
@@ -141,13 +141,6 @@ class Worker:
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.sliding_window = cache_config.sliding_window self.sliding_window = cache_config.sliding_window
if self.sliding_window is None:
max_seq_len = self.scheduler_config.max_model_len
else:
max_seq_len = min(self.scheduler_config.max_model_len,
self.sliding_window)
_check_if_can_support_max_seq_len(max_seq_len, self.block_size)
self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config) self.parallel_config)
self.cache_events = self.cache_engine.events self.cache_events = self.cache_engine.events
@@ -158,9 +151,13 @@ class Worker:
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
seq_groups: List[Tuple[List[int], SamplingParams]] = [] seq_groups: List[Tuple[List[int], SamplingParams]] = []
input_tokens: List[int] = [] input_tokens: List[List[int]] = []
input_positions: List[int] = [] input_positions: List[List[int]] = []
slot_mapping: List[int] = [] slot_mapping: List[List[int]] = []
selected_token_indices: List[int] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
# Add prompt tokens. # Add prompt tokens.
prompt_lens: List[int] = [] prompt_lens: List[int] = []
@@ -180,48 +177,82 @@ class Worker:
prompt_len = len(prompt_tokens) prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len) prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens) if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[sampling_params.sampling_type].append(
categorized_sample_indices_start_idx)
categorized_sample_indices_start_idx += 1
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(range(len(prompt_tokens))) input_positions.append(list(range(prompt_len)))
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping. # yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([0] * prompt_len) slot_mapping.append([0] * prompt_len)
continue continue
# Compute the slot mapping. # Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
for i in range(prompt_len): for i in range(prompt_len):
block_number = block_table[i // self.block_size] block_number = block_table[i // self.block_size]
block_offset = i % self.block_size block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append(slot) slot_mapping[-1].append(slot)
# Add generation tokens. # Add generation tokens.
max_context_len = 0 max_context_len = 0
max_num_blocks_per_seq = 0 max_num_blocks_per_seq = 0
context_lens: List[int] = [] context_lens: List[int] = []
generation_block_tables: List[List[int]] = [] generation_block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list: max_seq_len = max(prompt_lens) if prompt_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
if seq_group_metadata.is_prompt: if seq_group_metadata.is_prompt:
# We need to do this in this loop as we need to know max_seq_len
assert len(
seq_ids) == 1, "Prompt input should have only one seq."
sampling_params = seq_group_metadata.sampling_params
assert len(prompt_lens) == len(seq_group_metadata_list)
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
continue continue
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params)) seq_groups.append((seq_ids, sampling_params))
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[sampling_params.sampling_type].extend(
range(categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id() generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token) input_tokens.append([generation_token])
context_len = seq_data.get_len() context_len = seq_data.get_len()
position = context_len - 1 position = context_len - 1
if self.sliding_window is not None: if self.sliding_window is not None:
context_len = min(context_len, self.sliding_window) context_len = min(context_len, self.sliding_window)
input_positions.append(position) input_positions.append([position])
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
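Review note: both the prompt path and the generation path in this diff turn a token position into a flat cache slot using the same arithmetic. The sketch below isolates that calculation. The block table contents and block size are made-up values for illustration.

```python
from typing import List


def slot_for_position(block_table: List[int], position: int,
                      block_size: int) -> int:
    """Map a logical token position to a flat slot in the KV cache."""
    block_number = block_table[position // block_size]  # physical block
    block_offset = position % block_size                # index inside it
    return block_number * block_size + block_offset


# Hypothetical sequence stored in physical blocks 7 and 2, 4 slots each.
block_table = [7, 2]
block_size = 4
slots = [slot_for_position(block_table, i, block_size) for i in range(6)]
print(slots)
```

Positions 0 to 3 land in block 7 (slots 28 to 31) and positions 4 and 5 land in block 2 (slots 8 and 9), so logically adjacent tokens need not be adjacent in memory.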
@@ -233,7 +264,7 @@ class Worker:
block_number = block_table[position // self.block_size] block_number = block_table[position // self.block_size]
block_offset = position % self.block_size block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append(slot) slot_mapping.append([slot])
if self.sliding_window is not None: if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window // sliding_window_blocks = (self.sliding_window //
@@ -241,28 +272,42 @@ class Worker:
block_table = block_table[-sliding_window_blocks:] block_table = block_table[-sliding_window_blocks:]
generation_block_tables.append(block_table) generation_block_tables.append(block_table)
# Optimization: Pad the input length to be a multiple of 8. padded_input_tokens = [
# This is required for utilizing the Tensor Cores in NVIDIA GPUs. _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
input_tokens = _pad_to_alignment(input_tokens, multiple_of=8) ]
input_positions = _pad_to_alignment(input_positions, multiple_of=8) padded_input_positions = [
_pad_to_max(positions, max_seq_len, pad=0)
for positions in input_positions
]
padded_slot_mapping = [
_pad_to_max(mapping, max_seq_len, pad=-1)
for mapping in slot_mapping
]
padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
for block_table in generation_block_tables
]
# Convert to tensors. # Convert to tensors.
tokens_tensor = torch.tensor(input_tokens, tokens_tensor = torch.tensor(padded_input_tokens,
dtype=torch.long, dtype=torch.long,
device="cuda") device="cuda")
positions_tensor = torch.tensor(input_positions, positions_tensor = torch.tensor(padded_input_positions,
dtype=torch.long, dtype=torch.long,
device="cuda") device="cuda")
slot_mapping_tensor = torch.tensor(slot_mapping, slot_mapping_tensor = torch.tensor(padded_slot_mapping,
dtype=torch.int, dtype=torch.long,
device="cuda") device="cuda")
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device="cuda") device="cuda")
padded_block_tables = [ selected_token_indices = torch.tensor(selected_token_indices,
_pad_to_max(block_table, max_num_blocks_per_seq) dtype=torch.long,
for block_table in generation_block_tables device="cuda")
] categorized_sample_indices = {
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
for t, seq_ids in categorized_sample_indices.items()
}
block_tables_tensor = torch.tensor(padded_block_tables, block_tables_tensor = torch.tensor(padded_block_tables,
dtype=torch.int, dtype=torch.int,
device="cuda") device="cuda")
@@ -279,6 +324,8 @@ class Worker:
context_lens=context_lens_tensor, context_lens=context_lens_tensor,
max_context_len=max_context_len, max_context_len=max_context_len,
block_tables=block_tables_tensor, block_tables=block_tables_tensor,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
) )
return tokens_tensor, positions_tensor, input_metadata return tokens_tensor, positions_tensor, input_metadata
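Review note: once prompts are padded to `max_seq_len`, each prompt occupies a fixed-width row in the flattened batch, and `selected_token_indices` records which flattened positions the sampler should read. The sketch below reproduces that bookkeeping for a prompt-only batch. The function name and the boolean list standing in for `sampling_params.prompt_logprobs is not None` are assumptions for illustration.

```python
from typing import List


def select_prompt_token_indices(prompt_lens: List[int],
                                want_prompt_logprobs: List[bool]) -> List[int]:
    """Flattened indices to keep from a [num_prompts, max_seq_len] batch."""
    max_seq_len = max(prompt_lens) if prompt_lens else 1
    selected: List[int] = []
    start = 0
    for prompt_len, want_all in zip(prompt_lens, want_prompt_logprobs):
        if want_all:
            # Prompt logprobs requested: keep every prompt position too.
            selected.extend(range(start, start + prompt_len - 1))
        # Always keep the last real token, which predicts the next one.
        selected.append(start + prompt_len - 1)
        # Each row is max_seq_len wide regardless of the real length.
        start += max_seq_len
    return selected


last_only = select_prompt_token_indices([3, 5], [False, False])
with_logprobs = select_prompt_token_indices([3, 5], [True, False])
print(last_only)
print(with_logprobs)
```

With prompt lengths 3 and 5 the rows are 5 wide, so the last real tokens sit at flattened indices 2 and 9. Requesting prompt logprobs for the first prompt adds indices 0 and 1.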
@@ -361,32 +408,12 @@ def _init_distributed_environment(
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]: def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
return x + [0] * ((-len(x)) % multiple_of) return x + [pad] * ((-len(x)) % multiple_of)
def _pad_to_max(x: List[int], max_len: int) -> List[int]: def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
return x + [0] * (max_len - len(x)) return x + [pad] * (max_len - len(x))
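Review note: with the new `pad` argument, `_pad_to_max` can build rectangular batches with a different filler per tensor: 0 for tokens and positions, -1 for the slot mapping, presumably so that padded positions can be told apart from real cache slots. A minimal sketch with made-up token ids and slot numbers:

```python
from typing import List


def pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    """Right-pad `x` with `pad` until it is `max_len` long."""
    return x + [pad] * (max_len - len(x))


# Hypothetical batch of two prompts with different lengths.
input_tokens = [[11, 12, 13], [21, 22, 23, 24, 25]]
slot_mapping = [[28, 29, 30], [8, 9, 10, 11, 12]]
max_seq_len = max(len(tokens) for tokens in input_tokens)

padded_tokens = [pad_to_max(t, max_seq_len, pad=0) for t in input_tokens]
padded_slots = [pad_to_max(m, max_seq_len, pad=-1) for m in slot_mapping]
print(padded_tokens)
print(padded_slots)
```

Every row now has the same length, which is what allows the nested lists to be passed to `torch.tensor` as a 2-D batch.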
def _check_if_can_support_max_seq_len(max_seq_len: int,
block_size: int) -> None:
# Follows the logic in
# attention_kernels.cu::single_query_cached_kv_attention_launcher
max_shared_mem = get_max_shared_memory_bytes()
float32_bytes = torch.finfo(torch.float).bits // 8
padded_max_seq_len = (
(max_seq_len + block_size - 1) / block_size) * block_size
# padded_max_seq_len + extra buffer
required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
if padded_max_seq_len * float32_bytes > max_shared_mem:
raise RuntimeError(
f"vLLM cannot currently support max_model_len={max_seq_len} "
f"with block_size={block_size} on GPU with compute "
f"capability {torch.cuda.get_device_capability()} "
f"(required shared memory {required_shared_mem} > "
f"available shared memory {max_shared_mem}). "
"This will be fixed in a future release.")
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):