Compare commits
146 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e0b6e7757 | ||
|
|
941767127c | ||
|
|
74d8d77626 | ||
|
|
fd4ea8ef5c | ||
|
|
1066cbd152 | ||
|
|
6ef00b03a2 | ||
|
|
9140561059 | ||
|
|
77af974b40 | ||
|
|
4934d49274 | ||
|
|
358c328d69 | ||
|
|
4aaafdd289 | ||
|
|
66b108d142 | ||
|
|
e0ff920001 | ||
|
|
face83c7ec | ||
|
|
1db83e31a2 | ||
|
|
a1b9cb2a34 | ||
|
|
3a4fd5ca59 | ||
|
|
c17daa9f89 | ||
|
|
bd29cf3d3a | ||
|
|
31bff69151 | ||
|
|
ba4f826738 | ||
|
|
de60a3fb93 | ||
|
|
21d5daa4ac | ||
|
|
290e015c6c | ||
|
|
1b7c791d60 | ||
|
|
bbe4466fd9 | ||
|
|
08133c4d1a | ||
|
|
76a7983b23 | ||
|
|
8041b7305e | ||
|
|
3ec8c25cd0 | ||
|
|
671af2b1c0 | ||
|
|
6f41f0e377 | ||
|
|
2c9b638065 | ||
|
|
a7347d9a6d | ||
|
|
f8c688d746 | ||
|
|
c9fadda543 | ||
|
|
30fb0956df | ||
|
|
3a765bd5e1 | ||
|
|
26c52a5ea6 | ||
|
|
c3372e87be | ||
|
|
b0a1d667b0 | ||
|
|
e1d5402238 | ||
|
|
3d1cfbfc74 | ||
|
|
37ca558103 | ||
|
|
eed74a558f | ||
|
|
2acd76f346 | ||
|
|
b81a6a6bb3 | ||
|
|
0fbfc4b81b | ||
|
|
c06170cc8e | ||
|
|
614856da25 | ||
|
|
05bdf4eaf3 | ||
|
|
6774bd50b0 | ||
|
|
31c1f3255e | ||
|
|
21d93c140d | ||
|
|
f1c8520146 | ||
|
|
096827c284 | ||
|
|
6565d9e33e | ||
|
|
f375ec8440 | ||
|
|
518369d78c | ||
|
|
30bad5c492 | ||
|
|
3fefe271ec | ||
|
|
6428f1d051 | ||
|
|
7e1b21daac | ||
|
|
cb3f30c600 | ||
|
|
f3e024bece | ||
|
|
31d2ab4aff | ||
|
|
eb17212858 | ||
|
|
4dd4b5c538 | ||
|
|
6120e5aaea | ||
|
|
2eaa81b236 | ||
|
|
81ce2a4b26 | ||
|
|
5dd80d3777 | ||
|
|
beeee69bc9 | ||
|
|
9bf28d0b69 | ||
|
|
c0ce15dfb2 | ||
|
|
b9bcdc7158 | ||
|
|
4ff0203987 | ||
|
|
b5f882cc98 | ||
|
|
2e8fc0d4c3 | ||
|
|
dacaf5a400 | ||
|
|
24cde76a15 | ||
|
|
1aa1361510 | ||
|
|
fe470ae5ad | ||
|
|
3a8c2381f7 | ||
|
|
c85b80c2b6 | ||
|
|
2b981012a6 | ||
|
|
6ccc0bfffb | ||
|
|
c8e7eb1eb3 | ||
|
|
24f60a54f4 | ||
|
|
42c02f5892 | ||
|
|
ebede26ebf | ||
|
|
d940ce497e | ||
|
|
05ff90b692 | ||
|
|
1d9b737e05 | ||
|
|
60dc62dc9e | ||
|
|
0f90effc66 | ||
|
|
464dd985e3 | ||
|
|
c07a442854 | ||
|
|
cd3aa153a4 | ||
|
|
9b294976a2 | ||
|
|
5313c2cb8b | ||
|
|
5f09cbdb63 | ||
|
|
4cefa9b49b | ||
|
|
f86bd6190a | ||
|
|
e5452ddfd6 | ||
|
|
d06980dfa7 | ||
|
|
66785cc05c | ||
|
|
05a38612b0 | ||
|
|
d27f4bae39 | ||
|
|
8d8c2f6ffe | ||
|
|
51d3cb951d | ||
|
|
e74b1736a1 | ||
|
|
f07c1ceaa5 | ||
|
|
63b2206ad0 | ||
|
|
27feead2f8 | ||
|
|
c782195662 | ||
|
|
0f621c2c7d | ||
|
|
a9e4574261 | ||
|
|
0229c386c5 | ||
|
|
a7b3e33078 | ||
|
|
e19a64c7ef | ||
|
|
1cb4ad8de9 | ||
|
|
6ed068a71a | ||
|
|
708e6c18b0 | ||
|
|
b943890484 | ||
|
|
a1125ad4df | ||
|
|
a8b150c595 | ||
|
|
665cbcec4b | ||
|
|
7c600440f7 | ||
|
|
e0c6f556e8 | ||
|
|
de23687d16 | ||
|
|
4cea74c73b | ||
|
|
a921d8be9d | ||
|
|
094f716bf2 | ||
|
|
7d761fe3c1 | ||
|
|
cf35d8f3d7 | ||
|
|
4bb6b67188 | ||
|
|
819b18e7ba | ||
|
|
19849db573 | ||
|
|
3d4ceb292c | ||
|
|
f5a37c6c6c | ||
|
|
32c927b53f | ||
|
|
5ffc0d13a2 | ||
|
|
112627e8b2 | ||
|
|
37c1e3c218 | ||
|
|
06e9ebebd5 |
2
.github/workflows/publish.yml
vendored
2
.github/workflows/publish.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
matrix:
|
||||
os: ['ubuntu-20.04']
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||
pytorch-version: ['2.1.0']
|
||||
pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
|
||||
cuda-version: ['11.8', '12.1']
|
||||
|
||||
steps:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
name: pylint
|
||||
name: ruff
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
@@ -11,7 +11,7 @@ on:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -25,7 +25,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pylint==2.8.2
|
||||
- name: Analysing the code with pylint
|
||||
pip install ruff==0.1.5
|
||||
- name: Analysing the code with ruff
|
||||
run: |
|
||||
pylint vllm tests
|
||||
ruff vllm tests
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -177,3 +177,7 @@ _build/
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
||||
# hip files generated by PyTorch
|
||||
*.hip
|
||||
*_hip*
|
||||
|
||||
434
.pylintrc
434
.pylintrc
@@ -1,434 +0,0 @@
|
||||
# This Pylint rcfile contains a best-effort configuration to uphold the
|
||||
# best-practices and style described in the Google Python style guide:
|
||||
# https://google.github.io/styleguide/pyguide.html
|
||||
#
|
||||
# Its canonical open-source location is:
|
||||
# https://google.github.io/styleguide/pylintrc
|
||||
|
||||
[MASTER]
|
||||
|
||||
# Files or directories to be skipped. They should be base names, not paths.
|
||||
ignore=docs
|
||||
|
||||
# Files or directories matching the regex patterns are skipped. The regex
|
||||
# matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=no
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=4
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifiers separated by commas (,) or put this option
|
||||
# multiple times (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
#enable=
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once). You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use "--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=abstract-method,
|
||||
apply-builtin,
|
||||
arguments-differ,
|
||||
attribute-defined-outside-init,
|
||||
backtick,
|
||||
bad-option-value,
|
||||
basestring-builtin,
|
||||
buffer-builtin,
|
||||
c-extension-no-member,
|
||||
consider-using-enumerate,
|
||||
cmp-builtin,
|
||||
cmp-method,
|
||||
coerce-builtin,
|
||||
coerce-method,
|
||||
delslice-method,
|
||||
div-method,
|
||||
duplicate-code,
|
||||
eq-without-hash,
|
||||
execfile-builtin,
|
||||
file-builtin,
|
||||
filter-builtin-not-iterating,
|
||||
fixme,
|
||||
getslice-method,
|
||||
global-statement,
|
||||
hex-method,
|
||||
idiv-method,
|
||||
implicit-str-concat-in-sequence,
|
||||
import-error,
|
||||
import-self,
|
||||
import-star-module-level,
|
||||
inconsistent-return-statements,
|
||||
input-builtin,
|
||||
intern-builtin,
|
||||
invalid-str-codec,
|
||||
locally-disabled,
|
||||
logging-fstring-interpolation, # added by vLLM
|
||||
logging-not-lazy, # added by vLLM
|
||||
long-builtin,
|
||||
long-suffix,
|
||||
map-builtin-not-iterating,
|
||||
misplaced-comparison-constant,
|
||||
missing-class-docstring, # TODO (vLLM): enable
|
||||
missing-function-docstring,
|
||||
missing-module-docstring, # TODO (vLLM): enable
|
||||
metaclass-assignment,
|
||||
next-method-called,
|
||||
next-method-defined,
|
||||
no-absolute-import,
|
||||
no-else-break,
|
||||
no-else-continue,
|
||||
no-else-raise,
|
||||
no-else-return,
|
||||
no-init, # added
|
||||
no-member,
|
||||
no-name-in-module,
|
||||
no-self-use,
|
||||
nonzero-method,
|
||||
oct-method,
|
||||
old-division,
|
||||
old-ne-operator,
|
||||
old-octal-literal,
|
||||
old-raise-syntax,
|
||||
parameter-unpacking,
|
||||
print-statement,
|
||||
raising-string,
|
||||
range-builtin-not-iterating,
|
||||
raw_input-builtin,
|
||||
rdiv-method,
|
||||
reduce-builtin,
|
||||
relative-import,
|
||||
reload-builtin,
|
||||
round-builtin,
|
||||
setslice-method,
|
||||
signature-differs,
|
||||
standarderror-builtin,
|
||||
suppressed-message,
|
||||
sys-max-int,
|
||||
too-few-public-methods,
|
||||
too-many-ancestors,
|
||||
too-many-arguments,
|
||||
too-many-boolean-expressions,
|
||||
too-many-branches,
|
||||
too-many-instance-attributes,
|
||||
too-many-locals,
|
||||
too-many-nested-blocks,
|
||||
too-many-public-methods,
|
||||
too-many-return-statements,
|
||||
too-many-statements,
|
||||
trailing-newlines,
|
||||
unichr-builtin,
|
||||
unicode-builtin,
|
||||
unnecessary-pass,
|
||||
unpacking-in-except,
|
||||
unspecified-encoding,
|
||||
useless-else-on-loop,
|
||||
useless-object-inheritance,
|
||||
useless-suppression,
|
||||
using-cmp-argument,
|
||||
wrong-import-order,
|
||||
xrange-builtin,
|
||||
zip-builtin-not-iterating,
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||
# (visual studio) and html. You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
good-names=main,_
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
|
||||
|
||||
# Regular expression matching correct function names
|
||||
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression matching correct variable names
|
||||
variable-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct constant names
|
||||
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct attribute names
|
||||
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class attribute names
|
||||
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct inline iteration names
|
||||
inlinevar-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class names
|
||||
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
|
||||
|
||||
# Regular expression matching correct module names
|
||||
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
|
||||
|
||||
# Regular expression matching correct method names
|
||||
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=10
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=80
|
||||
|
||||
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
|
||||
# lines made too long by directives to pytype.
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=(?x)(
|
||||
^\s*(\#\ )?<?https?://\S+>?$|
|
||||
^\s*(from\s+\S+\s+)?import\s+.+$)
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=yes
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=99999
|
||||
|
||||
# String used as indentation unit. The internal Google style guide mandates 2
|
||||
# spaces. Google's externally-published style guide says 4, consistent with
|
||||
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
|
||||
# projects (like TensorFlow).
|
||||
indent-string=' '
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=TODO
|
||||
|
||||
|
||||
[STRING]
|
||||
|
||||
# This flag controls whether inconsistent-quotes generates a warning when the
|
||||
# character used as a quote delimiter is used inconsistently within a module.
|
||||
check-quote-consistency=yes
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid defining new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,_cb
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging,absl.logging,tensorflow.io.logging
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it work,
|
||||
# install python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=regsub,
|
||||
TERMIOS,
|
||||
Bastion,
|
||||
rexec,
|
||||
sets
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant, absl
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,
|
||||
_fields,
|
||||
_replace,
|
||||
_source,
|
||||
_make
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls,
|
||||
class_
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "Exception"
|
||||
overgeneral-exceptions=StandardError,
|
||||
Exception,
|
||||
BaseException
|
||||
16
Dockerfile
16
Dockerfile
@@ -18,6 +18,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
# image to build pytorch extensions
|
||||
FROM dev AS build
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements-build.txt
|
||||
|
||||
# copy input files
|
||||
COPY csrc csrc
|
||||
COPY setup.py setup.py
|
||||
@@ -25,8 +30,15 @@ COPY requirements.txt requirements.txt
|
||||
COPY pyproject.toml pyproject.toml
|
||||
COPY vllm/__init__.py vllm/__init__.py
|
||||
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# max jobs used by Ninja to build extensions
|
||||
ENV MAX_JOBS=$max_jobs
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# number of threads used by nvcc
|
||||
ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
|
||||
RUN python3 setup.py build_ext --inplace
|
||||
|
||||
# image to run unit testing suite
|
||||
@@ -63,7 +75,7 @@ ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
|
||||
FROM vllm-base AS vllm-openai
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install accelerate fschat
|
||||
pip install accelerate
|
||||
|
||||
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
|
||||
COPY vllm vllm
|
||||
|
||||
62
Dockerfile.rocm
Normal file
62
Dockerfile.rocm
Normal file
@@ -0,0 +1,62 @@
|
||||
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
|
||||
|
||||
# Install Python 3 and pip
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
build-essential \
|
||||
wget \
|
||||
unzip \
|
||||
nvidia-cuda-toolkit \
|
||||
tmux \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
### Mount Point ###
|
||||
# When launching the container, mount the code directory to /app
|
||||
ARG APP_MOUNT=/app
|
||||
VOLUME [ ${APP_MOUNT} ]
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
||||
|
||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
||||
|
||||
# Install ROCm flash-attention
|
||||
RUN mkdir libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout 3d2b6f5 \
|
||||
&& git submodule update --init \
|
||||
&& export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
|
||||
&& patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..
|
||||
|
||||
COPY ./ /app/vllm
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN pip install xformers==0.0.23 --no-deps
|
||||
|
||||
RUN cd /app \
|
||||
&& cd vllm \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& bash patch_xformers.rocm.sh \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir ray[all]
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
14
README.md
14
README.md
@@ -10,13 +10,14 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
||||
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2023/12] Added ROCm support to vLLM.
|
||||
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
|
||||
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
|
||||
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
|
||||
@@ -26,7 +27,7 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||
|
||||
vLLM is fast with:
|
||||
@@ -34,6 +35,8 @@ vLLM is fast with:
|
||||
- State-of-the-art serving throughput
|
||||
- Efficient management of attention key and value memory with **PagedAttention**
|
||||
- Continuous batching of incoming requests
|
||||
- Fast model execution with CUDA/HIP graph
|
||||
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629)
|
||||
- Optimized CUDA kernels
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
@@ -43,13 +46,15 @@ vLLM is flexible and easy to use with:
|
||||
- Tensor parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support for NVIDIA GPUs and AMD GPUs
|
||||
|
||||
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
||||
|
||||
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
||||
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
|
||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
||||
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
|
||||
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
|
||||
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
||||
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
||||
@@ -58,9 +63,10 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
|
||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
|
||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||
- Phi-1.5 (`microsoft/phi-1_5`, etc.)
|
||||
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -12,7 +14,6 @@ from vllm import LLM, SamplingParams
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
# Process all the requests in a single batch if possible.
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(
|
||||
@@ -20,10 +21,9 @@ def main(args: argparse.Namespace):
|
||||
tokenizer=args.tokenizer,
|
||||
quantization=args.quantization,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_num_seqs=args.batch_size,
|
||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
dtype=args.dtype,
|
||||
enforce_eager=args.enforce_eager,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
@@ -37,28 +37,43 @@ def main(args: argparse.Namespace):
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
|
||||
|
||||
def run_to_completion(profile: bool = False):
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return latency
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir))) as p:
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
print(p.key_averages())
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
return latency
|
||||
|
||||
print("Warming up...")
|
||||
run_to_completion(profile=False)
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
profile_dir = args.profile_result_dir
|
||||
if not profile_dir:
|
||||
profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=args.profile_result_dir)
|
||||
return
|
||||
|
||||
# Benchmark.
|
||||
latencies = []
|
||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||
latencies.append(run_to_completion(profile=False))
|
||||
latencies.append(run_to_completion(profile_dir=None))
|
||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
||||
|
||||
|
||||
@@ -70,7 +85,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'squeezellm', None],
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None)
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
@@ -97,5 +112,20 @@ if __name__ == '__main__':
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='enforce eager mode and disable CUDA graph')
|
||||
parser.add_argument(
|
||||
'--profile',
|
||||
action='store_true',
|
||||
help='profile the generation process of a single batch')
|
||||
parser.add_argument(
|
||||
'--profile-result-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
'path to save the pytorch profiler output. Can be visualized '
|
||||
'with ui.perfetto.dev or Tensorboard.'
|
||||
))
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -17,9 +17,8 @@ def sample_requests(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if fixed_output_len is not None:
|
||||
if fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
@@ -70,6 +69,8 @@ def run_vllm(
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
@@ -80,6 +81,8 @@ def run_vllm(
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@@ -202,7 +205,8 @@ def main(args: argparse.Namespace):
|
||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype)
|
||||
args.trust_remote_code, args.dtype,
|
||||
args.max_model_len, args.enforce_eager)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@@ -242,7 +246,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--tokenizer", type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'squeezellm', None],
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None)
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--n",
|
||||
@@ -262,6 +266,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
'--max-model-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Maximum length of a sequence (including prompt and output). '
|
||||
'If None, will be derived from the model.')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
@@ -271,6 +281,9 @@ if __name__ == "__main__":
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument("--enforce-eager",
|
||||
action="store_true",
|
||||
help="enforce eager execution")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
@@ -4,7 +4,7 @@ import time
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import attention_ops
|
||||
from vllm._C import ops
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
PARTITION_SIZE = 512
|
||||
@@ -37,10 +37,6 @@ def main(
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
@@ -98,12 +94,12 @@ def main(
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
attention_ops.paged_attention_v1(
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
@@ -112,7 +108,7 @@ def main(
|
||||
alibi_slopes,
|
||||
)
|
||||
elif version == "v2":
|
||||
attention_ops.paged_attention_v2(
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
@@ -120,7 +116,7 @@ def main(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
m.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
m.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
@@ -18,8 +20,8 @@ __global__ void silu_and_mul_kernel(
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = silu(x) * y;
|
||||
}
|
||||
}
|
||||
@@ -35,6 +37,7 @@ void silu_and_mul(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(d, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
@@ -57,7 +60,7 @@ __global__ void activation_kernel(
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * d + idx]);
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
@@ -70,6 +73,7 @@ __global__ void activation_kernel(
|
||||
int64_t num_tokens = input.numel() / d; \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
m.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
}
|
||||
@@ -15,15 +15,24 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
@@ -40,7 +49,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Compute the sum per warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Warp leaders store the data to shared memory.
|
||||
@@ -59,11 +68,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Parallel reduction inside the warp.
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Broadcast to other threads.
|
||||
return __shfl_sync(uint32_t(-1), sum, 0);
|
||||
return VLLM_SHFL_SYNC(sum, 0);
|
||||
}
|
||||
|
||||
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||
@@ -81,7 +90,7 @@ __device__ void paged_attention_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@@ -124,7 +133,8 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
const int head_idx = blockIdx.x;
|
||||
const int num_heads = gridDim.x;
|
||||
const int kv_head_idx = head_mapping[head_idx];
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
const int kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
|
||||
// A vector type to store a part of a key or a query.
|
||||
@@ -223,7 +233,7 @@ __device__ void paged_attention_kernel(
|
||||
// The 0-th thread of each thread group already has its max qk value.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = qk_max;
|
||||
@@ -235,10 +245,10 @@ __device__ void paged_attention_kernel(
|
||||
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
// Broadcast the max qk value to all threads.
|
||||
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
||||
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
||||
|
||||
// Get the sum of the exp values.
|
||||
float exp_sum = 0.f;
|
||||
@@ -326,7 +336,7 @@ __device__ void paged_attention_kernel(
|
||||
float acc = accs[i];
|
||||
#pragma unroll
|
||||
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
|
||||
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
|
||||
}
|
||||
accs[i] = acc;
|
||||
}
|
||||
@@ -393,7 +403,7 @@ __global__ void paged_attention_v1_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@@ -404,7 +414,7 @@ __global__ void paged_attention_v1_kernel(
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||
out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
|
||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
||||
}
|
||||
|
||||
@@ -422,7 +432,7 @@ __global__ void paged_attention_v2_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@@ -432,7 +442,7 @@ __global__ void paged_attention_v2_kernel(
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||
q_stride, kv_block_stride, kv_head_stride);
|
||||
}
|
||||
@@ -492,7 +502,7 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
// Reduce within the warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = max_logit;
|
||||
@@ -502,10 +512,10 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
// Broadcast the max value to all threads.
|
||||
max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
|
||||
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
|
||||
|
||||
// Load rescaled exp sums to shared memory.
|
||||
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||
@@ -539,16 +549,16 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||
cudaFuncSetAttribute( \
|
||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \
|
||||
shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
head_mapping_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
@@ -568,7 +578,7 @@ void paged_attention_v1_launcher(
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@@ -594,7 +604,6 @@ void paged_attention_v1_launcher(
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@@ -608,6 +617,7 @@ void paged_attention_v1_launcher(
|
||||
|
||||
dim3 grid(num_heads, num_seqs, 1);
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
@@ -643,7 +653,7 @@ void paged_attention_v1_launcher(
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
head_mapping, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
@@ -673,7 +683,7 @@ void paged_attention_v1(
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& head_mapping, // [num_heads]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
@@ -700,7 +710,7 @@ void paged_attention_v1(
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
head_mapping_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
@@ -731,7 +741,7 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@@ -760,7 +770,6 @@ void paged_attention_v2_launcher(
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@@ -777,6 +786,7 @@ void paged_attention_v2_launcher(
|
||||
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
@@ -815,7 +825,7 @@ void paged_attention_v2_launcher(
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
head_mapping, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
@@ -848,7 +858,7 @@ void paged_attention_v2(
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& head_mapping, // [num_heads]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../cuda_compat.h"
|
||||
#include "attention_dtypes.h"
|
||||
|
||||
#include <float.h>
|
||||
@@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
float qk = sum(qk_vec);
|
||||
#pragma unroll
|
||||
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
||||
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
|
||||
}
|
||||
return qk;
|
||||
}
|
||||
|
||||
@@ -21,8 +21,17 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
@@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return a + b;
|
||||
#ifndef USE_ROCM
|
||||
return a + b;
|
||||
#else
|
||||
return __hadd(a, b);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,10 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
@@ -63,21 +67,47 @@ struct FloatVec<uint4> {
|
||||
|
||||
// Utility functions for type conversions.
|
||||
inline __device__ uint32_t h0_h0(uint16_t a) {
|
||||
#ifndef USE_ROCM
|
||||
uint32_t b;
|
||||
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
|
||||
return b;
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u16[0] = a;
|
||||
tmp.u16[1] = a;
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ float half_to_float(uint16_t h) {
|
||||
float f;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
||||
#else
|
||||
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
|
||||
#endif
|
||||
return f;
|
||||
}
|
||||
|
||||
inline __device__ float2 half2_to_float2(uint32_t v) {
|
||||
#ifndef USE_ROCM
|
||||
uint16_t lo, hi;
|
||||
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
|
||||
return make_float2(half_to_float(lo), half_to_float(hi));
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u32 = v;
|
||||
float2 ret;
|
||||
ret.x = half_to_float(tmp.u16[0]);
|
||||
ret.y = half_to_float(tmp.u16[1]);
|
||||
return ret;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ uint16_t float_to_half(float f) {
|
||||
@@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
|
||||
#else
|
||||
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
|
||||
#endif
|
||||
return tmp.u16[0];
|
||||
}
|
||||
|
||||
@@ -94,12 +128,16 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
#ifndef USE_ROCM
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
#endif
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
tmp.u16[0] = float_to_half(f.x);
|
||||
tmp.u16[1] = float_to_half(f.y);
|
||||
#endif
|
||||
return tmp.u32;
|
||||
}
|
||||
@@ -107,13 +145,21 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
// Vector addition.
|
||||
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
@@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
|
||||
template<>
|
||||
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
@@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
||||
uint32_t d;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
|
||||
#endif
|
||||
return d;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <map>
|
||||
@@ -26,22 +28,3 @@ void gather_cached_kv(
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
m.def(
|
||||
"copy_blocks",
|
||||
&copy_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
m.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
m.def(
|
||||
"gather_cached_kv",
|
||||
&gather_cached_kv,
|
||||
"Gather key and value from the cache into contiguous QKV tensors");
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -28,10 +30,11 @@ void swap_blocks(
|
||||
TORCH_CHECK(false, "Invalid device combination");
|
||||
}
|
||||
|
||||
void *src_ptr = src.data_ptr();
|
||||
void *dst_ptr = dst.data_ptr();
|
||||
char *src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char *dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(src_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||
for (const auto& pair : block_mapping) {
|
||||
@@ -126,6 +129,7 @@ void copy_blocks(
|
||||
const int numel_per_block = key_caches[0][0].numel();
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
@@ -206,6 +210,7 @@ void reshape_and_cache(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
@@ -267,8 +272,8 @@ __global__ void gather_cached_kv_kernel(
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
|
||||
key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
|
||||
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,8 +338,8 @@ __global__ void gather_cached_kv_kernel_optimized(
|
||||
src_key_indices[j] = src_key_idx;
|
||||
src_value_indices[j] = src_value_idx;
|
||||
|
||||
keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
|
||||
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -366,6 +371,7 @@ void gather_cached_kv(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
|
||||
28
csrc/cuda_compat.h
Normal file
28
csrc/cuda_compat.h
Normal file
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
// Portability shims: each macro expands to the CUDA spelling of a primitive
// unless USE_ROCM is defined, in which case the ROCm/HIP spelling is used.

// Load through a pointer: `__ldg` on CUDA, a plain dereference on ROCm.
// The ROCm expansion is fully parenthesized so it composes with postfix
// operators (e.g. `VLLM_LDG(p).x`) exactly like the CUDA call expression does.
#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) (*(arg))
#endif

// XOR shuffle. The CUDA form passes uint32_t(-1) (all bits set) as the first
// argument of `__shfl_xor_sync`; the ROCm form has no such argument.
#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

// Indexed shuffle from `src_lane`, with the same first-argument convention as above.
#ifndef USE_ROCM
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

// Set the MaxDynamicSharedMemorySize function attribute of FUNC to VAL.
#ifndef USE_ROCM
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
}
|
||||
|
||||
7
csrc/cuda_utils.h
Normal file
7
csrc/cuda_utils.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
// Gets the specified device attribute.
// NOTE(review): `attribute` and `device_id` are presumably the attribute
// selector and the device ordinal - confirm against the definition.
int get_device_attribute(int attribute, int device_id);
|
||||
@@ -1,3 +1,6 @@
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
* Adapted from
|
||||
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& residual,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
m.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "reduction_utils.cuh"
|
||||
@@ -76,6 +77,7 @@ void rms_norm(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
@@ -101,6 +103,7 @@ void fused_add_rms_norm(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
|
||||
91
csrc/ops.h
Normal file
91
csrc/ops.h
Normal file
@@ -0,0 +1,91 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& residual,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
torch::Tensor gptq_gemm(
|
||||
torch::Tensor a,
|
||||
torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales,
|
||||
torch::Tensor b_g_idx,
|
||||
bool use_exllama);
|
||||
|
||||
void gptq_shuffle(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm);
|
||||
@@ -1,16 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
@@ -19,14 +21,14 @@ inline __device__ void apply_rotary_embedding(
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = __ldg(cos_ptr + x_index);
|
||||
sin = __ldg(sin_ptr + x_index);
|
||||
cos = VLLM_LDG(cos_ptr + x_index);
|
||||
sin = VLLM_LDG(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = __ldg(cos_ptr + x_index / 2);
|
||||
sin = __ldg(sin_ptr + x_index / 2);
|
||||
cos = VLLM_LDG(cos_ptr + x_index / 2);
|
||||
sin = VLLM_LDG(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
@@ -42,8 +44,8 @@ __global__ void rotary_embedding_kernel(
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int query_stride,
|
||||
const int key_stride,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
@@ -59,7 +61,7 @@ __global__ void rotary_embedding_kernel(
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
@@ -68,7 +70,7 @@ __global__ void rotary_embedding_kernel(
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
@@ -88,11 +90,12 @@ void rotary_embedding(
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int query_stride = query.stride(-2);
|
||||
int key_stride = key.stride(-2);
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
|
||||
84
csrc/pybind.cpp
Normal file
84
csrc/pybind.cpp
Normal file
@@ -0,0 +1,84 @@
|
||||
#include "cache.h"
|
||||
#include "cuda_utils.h"
|
||||
#include "ops.h"
|
||||
#include <torch/extension.h>
|
||||
|
||||
// Single Python extension module exposing every native op, grouped into the
// submodules `ops`, `cache_ops` and `cuda_utils`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // vLLM custom ops
  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

  // Attention ops
  ops.def(
    "paged_attention_v1",
    &paged_attention_v1,
    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
  ops.def(
    "paged_attention_v2",
    &paged_attention_v2,
    "PagedAttention V2.");

  // Activation ops
  ops.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  ops.def(
    "gelu_new",
    &gelu_new,
    "GELU implementation used in GPT-2.");
  ops.def(
    "gelu_fast",
    &gelu_fast,
    "Approximate GELU implementation.");

  // Layernorm
  ops.def(
    "rms_norm",
    &rms_norm,
    "Apply Root Mean Square (RMS) Normalization to the input tensor.");

  ops.def(
    "fused_add_rms_norm",
    &fused_add_rms_norm,
    "In-place fused Add and RMS Normalization");

  // Rotary embedding
  ops.def(
    "rotary_embedding",
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

#ifndef USE_ROCM
  // Quantization ops
  // awq_gemm is only registered for non-ROCm builds; the remaining
  // quantization ops below are registered unconditionally.
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
#endif
  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def(
    "swap_blocks",
    &swap_blocks,
    "Swap in (out) the cache blocks from src to dst");
  cache_ops.def(
    "copy_blocks",
    &copy_blocks,
    "Copy the cache blocks from src to dst");
  cache_ops.def(
    "reshape_and_cache",
    &reshape_and_cache,
    "Reshape the key and value tensors and cache them");
  cache_ops.def(
    "gather_cached_kv",
    &gather_cached_kv,
    "Gather key and value from the cache into contiguous QKV tensors");

  // Cuda utils
  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
  cuda_utils.def(
    "get_device_attribute",
    &get_device_attribute,
    "Gets the specified device attribute.");
}
|
||||
@@ -1,19 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
}
|
||||
64
csrc/quantization/gptq/compat.cuh
Normal file
64
csrc/quantization/gptq/compat.cuh
Normal file
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
namespace vllm {
namespace gptq {

// Software atomic add of one `half`, built on a 32-bit atomicCAS
// (to support CC < 7.x).
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    // Bit 1 of the address selects which 16-bit half of the enclosing
    // 4-byte-aligned word holds the target value.
    const bool in_upper_half = ((size_t)address & 2) != 0;
    unsigned int* word_ptr = (unsigned int*)((char*)address - ((size_t)address & 2));

    unsigned int observed = *word_ptr;
    unsigned int expected;

    do
    {
        expected = observed;

        // Extract the current half, add, and splice the sum back into the word.
        __half_raw bits;
        bits.x = in_upper_half ? (observed >> 16) : (observed & 0xffff);
        half sum = __hadd(bits, val);
        bits = __half_raw(sum);

        const unsigned int desired = in_upper_half
            ? (observed & 0xffff) | (bits.x << 16)
            : (observed & 0xffff0000) | bits.x;

        // Retry until the word was not changed between the read and the swap.
        observed = atomicCAS(word_ptr, expected, desired);
    }
    while (expected != observed);
}

// Software atomic add of a `half2`, treating the pair as one 32-bit word.
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* word_ptr = (unsigned int*)address;
    unsigned int observed = *word_ptr;
    unsigned int expected;

    do
    {
        expected = observed;
        half2 current = *((half2*)&observed);
        half2 updated = __hadd2(current, val);
        observed = atomicCAS(word_ptr, expected, *((unsigned int*)&updated));
    }
    while (expected != observed);
}

// Provide atomicAdd overloads where they are needed: the `half` overload for
// __CUDA_ARCH__ < 700 or ROCm, and additionally the `half2` overload for
// __CUDA_ARCH__ < 600 or ROCm.
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val)
{
    atomicAdd_half(address, val);
}

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val)
{
    atomicAdd_half2(address, val);
}
#endif

#endif
#endif

}  // namespace gptq
}  // namespace vllm
|
||||
#endif
|
||||
151
csrc/quantization/gptq/matrix_view.cuh
Normal file
151
csrc/quantization/gptq/matrix_view.cuh
Normal file
@@ -0,0 +1,151 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
|
||||
*/
|
||||
|
||||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
// Read-only view of a `half` buffer addressed as data[row * width + column].
class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Single element at (row, column).
    __device__ __forceinline__ half item(int row, int column) const
    {
        return data[row * width + column];
    }

    // The half2 pair containing element (row, column): the buffer is
    // reinterpreted as half2 and indexed by half the flat element index.
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        return ((half2*)data)[(row * width + column) / 2];
    }

    // Element (row, column) duplicated into both lanes of a half2.
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const
    {
        return __half2half2(data[row * width + column]);
    }

    // Address of element (row, column).
    __device__ __forceinline__ const half* item_ptr(int row, int column) const
    {
        return &data[row * width + column];
    }

    // Four consecutive elements starting at (row, column), read as two half2 loads.
    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
    {
        const half2* pairs = (const half2*) item_ptr(row, column);
        const half2 first_pair = pairs[0];
        const half2 second_pair = pairs[1];
        items[0] = __low2half(first_pair);
        items[1] = __high2half(first_pair);
        items[2] = __low2half(second_pair);
        items[3] = __high2half(second_pair);
    }

    // Same as item4, with each element converted to float.
    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
    {
        const half2* pairs = (const half2*) item_ptr(row, column);
        const half2 first_pair = pairs[0];
        const half2 second_pair = pairs[1];
        items[0] = __half2float(__low2half(first_pair));
        items[1] = __half2float(__high2half(first_pair));
        items[2] = __half2float(__low2half(second_pair));
        items[3] = __half2float(__high2half(second_pair));
    }

    // Same as item4, with each element broadcast into both lanes of a half2.
    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
    {
        const half2* pairs = (const half2*) item_ptr(row, column);
        const half2 first_pair = pairs[0];
        const half2 second_pair = pairs[1];
        items[0] = __half2half2(__low2half(first_pair));
        items[1] = __half2half2(__high2half(first_pair));
        items[2] = __half2half2(__low2half(second_pair));
        items[3] = __half2half2(__high2half(second_pair));
    }
};
|
||||
|
||||
// Read/write view of a `half` buffer addressed as data[row * width + column].
class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Single element at (row, column).
    __device__ __forceinline__ half item(int row, int column) const
    {
        return data[row * width + column];
    }

    // The half2 pair containing element (row, column).
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        return ((half2*)data)[(row * width + column) / 2];
    }

    // Mutable address of element (row, column).
    __device__ __forceinline__ half* item_ptr(int row, int column)
    {
        return &data[row * width + column];
    }

    // Store one element.
    __device__ __forceinline__ void set(int row, int column, half value)
    {
        data[row * width + column] = value;
    }

    // Store a half2 into the pair slot containing element (row, column).
    __device__ __forceinline__ void set_half2(int row, int column, half2 value)
    {
        ((half2*)data)[(row * width + column) / 2] = value;
    }

    // Store four consecutive elements starting at (row, column) as two half2 writes.
    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
    {
        const half2 first_pair = __halves2half2(v0, v1);
        const half2 second_pair = __halves2half2(v2, v3);
        half2* dst = (half2*) item_ptr(row, column);
        dst[0] = first_pair;
        dst[1] = second_pair;
    }
};
|
||||
|
||||
// Read-only view of 4-bit values packed eight per uint32_t along each row:
// the word index is row * width / 8 + column / 8 and the nibble inside the
// word is selected by (column & 7) * 4.
class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // The 4-bit value at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        const int shift = (column & 0x07) * 4;
        const uint32_t word = data[row * width / 8 + column / 8];
        return (word >> shift) & 0x0f;
    }

    // Two consecutive 4-bit values starting at (row, column), taken from one word.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        const int shift = (column & 0x07) * 4;
        const uint32_t nibbles = data[row * width / 8 + column / 8] >> shift;
        items[0] = nibbles & 0x0f;
        items[1] = (nibbles >> 4) & 0x0f;
    }

    // Four consecutive 4-bit values starting at (row, column), taken from one word.
    // NOTE(review): only a single word is read, so this presumably expects
    // `column` to leave four nibbles inside that word - confirm with callers.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const int shift = (column & 0x07) * 4;
        const uint32_t nibbles = data[row * width / 8 + column / 8] >> shift;
        items[0] = nibbles & 0x0f;
        items[1] = (nibbles >> 4) & 0x0f;
        items[2] = (nibbles >> 8) & 0x0f;
        items[3] = (nibbles >> 12) & 0x0f;
    }
};
|
||||
|
||||
// Read-only view of 4-bit values packed eight per uint32_t down each column:
// the word index is row / 8 * width + column and the nibble inside the word
// is selected by (row & 7) * 4.
class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // The 4-bit value at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    // The whole packed word that contains (row, column).
    // Declared const (like item) because it only reads, so it is also
    // callable through a const view.
    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) const
    {
        return data[row / 8 * width + column];
    }

    // Address of the packed word that contains (row, column). Const for the
    // same reason as item_uint32_t.
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) const
    {
        return &data[row / 8 * width + column];
    }
};
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
||||
875
csrc/quantization/gptq/q_gemm.cu
Normal file
875
csrc/quantization/gptq/q_gemm.cu
Normal file
@@ -0,0 +1,875 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa
|
||||
*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "compat.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "qdq_4.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
// Tile extent along K and thread count per block for the GPTQ kernels in this file.
#define BLOCK_KN_SIZE 128
// Largest m_count for which a GEMM kernel instantiation is selectable.
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
// NOTE(review): the uses of the four limits below are outside this excerpt.
#define MAX_Q_GEMM_ROWS 50
#define MAX_ALT_GEMM_ROWS 8
#define THREADS_X 32
#define THREADS_Y 32
// Integer division of x by size, rounded up.
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||
|
||||
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>

// Thin forwarder to hipblasHgemm that accepts `half` pointers and
// reinterprets them as `hipblasHalf` pointers; all other arguments pass
// through unchanged.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
    hipblasHandle_t handle,
    hipblasOperation_t transA,
    hipblasOperation_t transB,
    int m,
    int n,
    int k,
    const half* alpha,
    const half* AP,
    int lda,
    const half* BP,
    int ldb,
    const half* beta,
    half* CP,
    int ldc)
{
    const hipblasHalf* alpha_h = reinterpret_cast<const hipblasHalf *>(alpha);
    const hipblasHalf* a_h     = reinterpret_cast<const hipblasHalf *>(AP);
    const hipblasHalf* b_h     = reinterpret_cast<const hipblasHalf *>(BP);
    const hipblasHalf* beta_h  = reinterpret_cast<const hipblasHalf *>(beta);
    hipblasHalf* c_h           = reinterpret_cast<hipblasHalf *>(CP);

    return hipblasHgemm(handle, transA, transB, m, n, k,
                        alpha_h,
                        a_h, lda,
                        b_h, ldb,
                        beta_h,
                        c_h, ldc);
}
// From here on, calls spelled hipblasHgemm go through the forwarder above.
// This must stay after the forwarder's definition so the call inside it
// still reaches the real library function.
#define hipblasHgemm __compat_hipblasHgemm

// Earlier PyTorch versions converted to rocBLAS instead of hipBLAS, so the
// rocBLAS spellings are mapped as well.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
|
||||
|
||||
// Accumulates dq[i] * a2[i] for i = 0..3 with half2 fused multiply-add, where
// a2 is `a_ptr` read as four half2 pairs (eight halves), then adds `g_result`.
// Returns the per-lane half2 result.
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
    const half2* a_pairs = (const half2*)a_ptr;
    half2 acc = {};
    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        acc = __hfma2(dq[i], a_pairs[i], acc);
    }
    return __hadd2(acc, g_result);
}
|
||||
|
||||
// Accumulates dq[i] * a2[i] for i = 0..3 with half2 fused multiply-add, where
// a2 is `a_ptr` read as four half2 pairs (eight halves), and returns the sum
// of the two accumulator lanes as a float.
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
    const half2* a_pairs = (const half2*)a_ptr;
    half2 acc = {};
    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        acc = __hfma2(dq[i], a_pairs[i], acc);
    }
    const float low_lane = __half2float(__low2half(acc));
    const float high_lane = __half2float(__high2half(acc));
    return low_lane + high_lane;
}
|
||||
|
||||
// Pointer to an instantiation of gemm_half_q_half_gptq_kernel. The trailing
// comments give the corresponding kernel parameter names.
using fp_gemm_half_q_half_gptq_kernel = void (*)(
    const half*,      // a
    const uint32_t*,  // b_q_weight
    const uint32_t*,  // b_gptq_qzeros
    const half*,      // b_gptq_scales
    half*,            // c
    const int,        // size_m
    const int,        // size_n
    const int,        // size_k
    const int,        // groups
    const int*        // b_q_perm
);
|
||||
|
||||
// GPTQ 4-bit GEMM tile kernel: accumulates a (m_count x 4) piece of
// c += a * dequant(b_q_weight) per thread.
//
// Work split, as read from the indexing below:
//   blockIdx.x -> a span of BLOCK_KN_SIZE * 4 output columns (4 per thread),
//   blockIdx.y -> m_count consecutive rows of a / c,
//   blockIdx.z -> a slice of BLOCK_KN_SIZE values along k.
// Partial sums from the k-slices are merged into c with atomicAdd; the
// blockIdx.z == 0 blocks zero their part of c first.
// If b_q_perm is non-null, the k index used to read a is looked up through it.
// The template parameter first_block is not referenced in the body.
//
// NOTE(review): nothing in this kernel orders the zeroing done by the
// blockIdx.z == 0 blocks before the atomicAdd of other k-slices - confirm
// what the caller guarantees about c.
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
(
    const half* __restrict__ a,
    const uint32_t* __restrict__ b_q_weight,
    const uint32_t* __restrict__ b_gptq_qzeros,
    const half* __restrict__ b_gptq_scales,
    half* __restrict__ c,
    const int size_m,
    const int size_n,
    const int size_k,
    const int groups,
    const int* __restrict__ b_q_perm
)
{
    MatrixView_half a_(a, size_m, size_k);
    MatrixView_half_rw c_(c, size_m, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    int t = threadIdx.x;

    // Origin of this block's tile
    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
    int offset_m = blockIdx.y * m_count;
    int offset_k = blockIdx.z * BLOCK_KN_SIZE;

    // Tile ends clamped to the matrix extents (only end_k is consulted below)
    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
    int end_m = min(offset_m + m_count, size_m);
    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    // First of the four output columns owned by this thread
    int n = offset_n + t * 4;

    // Stage this block's slice of a in shared memory: thread t loads column
    // offset_k + t (optionally permuted) for each of the m_count rows.
    __shared__ half block_a[m_count][BLOCK_KN_SIZE];

    if (offset_k + t < end_k)
    {
        const int src_k = b_q_perm ? b_q_perm[offset_k + t] : offset_k + t;
        for (int m = 0; m < m_count; ++m)
        {
            const half* a_row = a_.item_ptr(offset_m + m, 0);
            block_a[m][t] = a_row[src_k];
        }
    }

    // Threads whose columns fall outside the matrix stop here (this happens
    // before the barrier below).
    if (n >= size_n) return;

    // Zero output: one 64-bit store clears the four halves at (row, n..n+3)
    if (blockIdx.z == 0)
    {
        for (int m = 0; m < m_count; m++)
            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
    }

    __syncthreads();

    // Find initial group
    int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // a, b offset: eight 4-bit weights per 32-bit word along k
    int qk = offset_k / (32 / 4);

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
    const half* a_ptr = &block_a[0][0];
    int a_stride = BLOCK_KN_SIZE;

    // Zero points / scales of the initial group for the four columns
    int zeros[4];
    float scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_f(scales, group, n);
    #pragma unroll
    for (int col = 0; col < 4; col++)
    {
        dequant_4bit_8_prep_zero(zeros[col] + 1, z1z16[col], y1y16[col]);
    }

    // Column result
    float block_c[m_count][4] = {};

    // Dequantize and multiply, 32 k-values per outer iteration
    int k = offset_k;
    while (k < end_k)
    {
        // Crossing into the next quantization group: reload zeros and scales
        if (k == nextgroup)
        {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_f(scales, group, n);
            #pragma unroll
            for (int col = 0; col < 4; col++)
            {
                dequant_4bit_8_prep_zero(zeros[col] + 1, z1z16[col], y1y16[col]);
            }
        }

        #pragma unroll
        for (int j = 0; j < 4; j++)
        {
            // One 128-bit load: a packed word for each of the four columns
            const int4* b_ptr4 = (int4*) b_ptr;
            int4 load_int4 = *b_ptr4;

            half2 dq[4][4];
            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);

            #pragma unroll
            for (int m = 0; m < m_count; m++)
            {
                const half* a_row = a_ptr + m * a_stride;
                #pragma unroll
                for (int col = 0; col < 4; col++)
                {
                    block_c[m][col] = fma(dot22_8_f(dq[col], a_row), scales[col], block_c[m][col]);
                }
            }

            // Next row of packed weights, next eight staged a values
            b_ptr += size_n;
            a_ptr += 8;
        }

        k += 32;
    }

    // Fold this k-slice's partial sums into c, two half2 atomics per row
    for (int m = 0; m < m_count; m++)
    {
        half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
        half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
        half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
        atomicAdd(out    , result01);
        atomicAdd(out + 1, result23);
    }
}
|
||||
|
||||
|
||||
// Returns the gemm_half_q_half_gptq_kernel instantiation for `m_count`, or
// NULL when m_count is not in 1..BLOCK_M_SIZE_MAX (at most 8).
// `first_block` is not consulted: every returned kernel is instantiated with
// first_block = true.
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
    switch (m_count)
    {
    #if BLOCK_M_SIZE_MAX >= 1
        case 1: return gemm_half_q_half_gptq_kernel<true, 1>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 2
        case 2: return gemm_half_q_half_gptq_kernel<true, 2>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 3
        case 3: return gemm_half_q_half_gptq_kernel<true, 3>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 4
        case 4: return gemm_half_q_half_gptq_kernel<true, 4>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 5
        case 5: return gemm_half_q_half_gptq_kernel<true, 5>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 6
        case 6: return gemm_half_q_half_gptq_kernel<true, 6>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 7
        case 7: return gemm_half_q_half_gptq_kernel<true, 7>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 8
        case 8: return gemm_half_q_half_gptq_kernel<true, 8>;
    #endif
        default: return NULL;
    }
}
|
||||
|
||||
|
||||
// Launches the GPTQ GEMM kernel selected for `m_count` on the current CUDA
// stream, with:
//   grid.x = ceil(size_n / (BLOCK_KN_SIZE * 4)),
//   grid.y = ceil(size_m / m_count),
//   grid.z = ceil(size_k / BLOCK_KN_SIZE),
// and BLOCK_KN_SIZE threads per block.
//
// Raises (via TORCH_CHECK) if no kernel instantiation exists for `m_count`,
// instead of launching a null function pointer; the check also runs before
// m_count is used as a divisor.
void gemm_half_q_half_cuda_part
(
    const half* a,
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_q_perm,
    half* c,
    int size_m,
    int size_n,
    int size_k,
    int m_count,
    int groups
)
{
    fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
    TORCH_CHECK(kernel != NULL,
                "gemm_half_q_half_cuda_part: unsupported m_count ", m_count);

    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    blockDim.z = 1;
    gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
    gridDim.y = DIVIDE(size_m, m_count);
    gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    kernel<<<gridDim, blockDim, 0, stream>>>
    (
        a,
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        c,
        size_m,
        size_n,
        size_k,
        groups,
        b_q_perm
    );
}
|
||||
|
||||
|
||||
// Dequantizes one tile of a 4-bit GPTQ weight matrix (exllama layout) into the
// half-precision matrix `b`.
//
// Work split, as fixed by the indexing below:
//   * blockIdx.y selects a band of BLOCK_KN_SIZE rows (k);
//   * blockIdx.x selects a band of BLOCK_KN_SIZE * 4 columns (n);
//   * each thread dequantizes 4 adjacent columns starting at `n`.
//
// Parameters:
//   b_q_weight     packed weights; each uint32 holds 8 4-bit values along k.
//   b_q_perm       optional row permutation (may be NULL). When present, the
//                  i-th row of this tile is written to output row perm[i].
//   b_gptq_qzeros  packed 4-bit zero points, viewed as (groups x size_n).
//   b_gptq_scales  half scales, viewed as (groups x size_n).
//   size_k, size_n dimensions of the output view.
//   groups         number of quantization groups along k.
//   b              output, viewed as (size_k x size_n).
//
// NOTE(review): the int4 load below reads 4 consecutive uint32 words per thread;
// this presumably requires size_n to be a multiple of 4 and a 16-byte aligned
// b_q_weight -- confirm against callers.
__global__ void reconstruct_exllama_kernel
(
    const uint32_t* __restrict__ b_q_weight,
    const int* __restrict__ b_q_perm,
    const uint32_t* __restrict__ b_gptq_qzeros,
    const half* __restrict__ b_gptq_scales,
    const int size_k,
    const int size_n,
    const int groups,
    half* __restrict__ b
)
{
    MatrixView_half_rw b_(b, size_k, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    const int t = threadIdx.x;
    const int offset_k = BLOCK_KN_SIZE * blockIdx.y;
    const int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
    const int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    // Preload remapping table
    __shared__ int perm[BLOCK_KN_SIZE];
    if (b_q_perm && offset_k + t < size_k)
        perm[t] = b_q_perm[offset_k + t];

    // The barrier must be reached by every thread of the block, so it is
    // placed before the per-thread early exit below. Nothing between the
    // preload and the main loop touches shared memory, so hoisting it here
    // does not change what the surviving threads observe.
    __syncthreads();

    // Column
    const int n = offset_n + t * 4;
    if (n >= size_n) return;

    // Find initial group
    const int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // First packed row of this tile (8 values per uint32)
    const int qk = offset_k / (32 / 4);
    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;

    // Initial zeros/scale for the 4 columns owned by this thread
    int zeros[4];
    half2 scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_h2(scales, group, n);
    #pragma unroll
    for (int v = 0; v < 4; v++)
        dequant_4bit_8_prep_zero(zeros[v] + 1, z1z16[v], y1y16[v]);

    int lk = 0;  // row index local to this tile

    for (int k = offset_k; k < end_k; k += 32)
    {
        if (k == nextgroup)
        {
            // Crossed into the next quantization group: reload its parameters
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_h2(scales, group, n);
            #pragma unroll
            for (int v = 0; v < 4; v++)
                dequant_4bit_8_prep_zero(zeros[v] + 1, z1z16[v], y1y16[v]);
        }

        // 4 packed rows x 8 values each = 32 output rows per outer iteration
        for (int p = 0; p < 4; p++)
        {
            half2 dq[4][4];
            const int4 load_int4 = *reinterpret_cast<const int4*>(b_ptr);
            b_ptr += size_n;

            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);

            for (int j = 0; j < 4; j++)
            {
                for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);

                // Each half2 carries two consecutive rows (low, high)
                const int row_lo = b_q_perm ? perm[lk]     : offset_k + lk;
                const int row_hi = b_q_perm ? perm[lk + 1] : offset_k + lk + 1;
                lk += 2;

                b_.set4(row_lo, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
                b_.set4(row_hi, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
            }
        }
    }
}
|
||||
|
||||
|
||||
void reconstruct_exllama
|
||||
(
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_q_perm,
|
||||
half* out,
|
||||
int height,
|
||||
int width,
|
||||
int groups
|
||||
)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
b_q_weight,
|
||||
b_q_perm,
|
||||
b_gptq_qzeros,
|
||||
b_gptq_scales,
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
out
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Half x 4-bit-quantized matrix product that honours an explicit per-row group
// index (g_idx). Partial sums of each k-tile are atomically added into `mul`.
//
// Work split, as fixed by the indexing below:
//   * blockIdx.x / threadIdx.x select the output column w;
//   * blockIdx.y selects a batch tile of up to BLOCK_M_SIZE_MAX rows;
//   * blockIdx.z selects a k-tile of BLOCK_KN_SIZE input rows.
//
// Parameters:
//   vec     activations as half2; each batch row holds height * 4 half2 values.
//   mat     packed weights, read as mat[packed_row * width + w]; each uint32
//           is consumed as 4 bytes, i.e. 8 4-bit values.
//   mul     output, indexed as mul[row * width + w].
//   scales  per-group scales, indexed as scales[g * width + w].
//   zeros   packed 4-bit zero points, indexed as zeros[g * (width / 8) + w / 8].
//   g_idx   group index of every unpacked input row.
//   batch   number of activation rows.
//   height  number of packed weight rows (the host launcher in this file
//           passes size_k / 8).
//   width   number of output columns.
//
// NOTE(review): only the blockIdx.z == 0 blocks clear `mul`, while every
// z-block accumulates into it; __syncthreads() orders threads inside a block
// only. Confirm that no other z-block can add before the clear, or pre-zero
// the output on the host.
__global__ void gemm_half_q_half_alt_kernel(
    const half2* __restrict__ vec,
    const uint32_t* __restrict__ mat,
    half* __restrict__ mul,
    const half* __restrict__ scales,
    const uint32_t* __restrict__ zeros,
    const int* __restrict__ g_idx,
    int batch,
    int height,
    int width
)
{
    int zero_width = width / 8;
    int vec_height = height * 4;
    const int blockwidth2 = BLOCK_KN_SIZE / 2;
    int b = blockIdx.y * BLOCK_M_SIZE_MAX;
    int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
    int h = BLOCK_KN_SIZE * blockIdx.z / 8;
    int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
    int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

    // The grid is ceil-divided over `width`, so the last block in x may own
    // columns past the end. Such threads still help fill shared memory (that
    // work depends on threadIdx only) but must not touch global memory.
    const bool in_range = w < width;

    // Stage this k-tile of the activations for every batch row of the tile
    __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
    if (threadIdx.x < h_end) {
        for (int m = 0; m < b_end; ++m) {
            blockvec[m][threadIdx.x] =
                vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
                    threadIdx.x];
        }
    }

    // Lookup table: byte value -> half2(low nibble, high nibble). The second
    // dimension holds 8 identical copies selected by threadIdx.x % 8.
    __shared__ half2 deq2[256][8];
    int val = threadIdx.x / 8;
    int off = threadIdx.x % 8;
    for (; val < 256; val += BLOCK_KN_SIZE / 8) {
        deq2[val][off] = __halves2half2(
            __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
        );
    }

    if (blockIdx.z == 0 && in_range)
    {
        for (int m = 0; m < b_end; m++)
            mul[(b + m) * width + w] = __int2half_rn(0);
    }
    __syncthreads();

    // No barrier follows, so out-of-range threads can leave here.
    if (!in_range) return;

    int i = width * h + w;
    int g_h = h * 8;
    int k = 0;
    int z_w = w / 8;
    int z_mod = (w % 8) * 4;
    half2 res2;
    half res[BLOCK_M_SIZE_MAX] = {};

    unsigned int tmp;
    while (k < h_end) {
        tmp = mat[i];

        // Per-byte scale and (scaled, negated) zero point; the two halves of
        // each half2 belong to two consecutive input rows.
        half2 scales_tmp[4];
        half2 zeros_tmp[4];
        for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
            int g = g_idx[g_h + (k + tmp_k) * 2];
            int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
            half scale_f = scales[g * width + w];
            half scale_f2 = scales[g2 * width + w];
            // Stored zero points are offset by one: effective zero = z + 1
            int neg_zero = -(int)((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1;
            int neg_zero2 = -(int)((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1;
            scales_tmp[tmp_k] = __halves2half2(scale_f, scale_f2);
            zeros_tmp[tmp_k] = __halves2half2(
                __hmul(scale_f, __int2half_rn(neg_zero)),
                __hmul(scale_f2, __int2half_rn(neg_zero2))
            );
        }

        for (int m = 0; m < b_end; m++) {
#ifndef USE_ROCM
            res2 = {};
#else
            res2.x = __half_as_ushort(__float2half(0));
            res2.y = __half_as_ushort(__float2half(0));
#endif
            // (q * scale - zero * scale) * activation, accumulated per byte
            res2 = __hfma2(__hfma2(deq2[(tmp >>  0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >>  8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
#ifndef USE_ROCM
            res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
#else
            res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
#endif
        }
        i += width;
        k += 4;
    }

    // Combine this k-tile's contribution with those of the other z-blocks
    for (int m = 0; m < b_end; m++) {
        atomicAdd(&mul[(b + m) * width + w], res[m]);
    }
}
|
||||
|
||||
|
||||
void gemm_half_q_half_alt
|
||||
(
|
||||
const half* a,
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_g_idx,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k
|
||||
)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
|
||||
gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const half2*) a,
|
||||
b_q_weight,
|
||||
c,
|
||||
b_gptq_scales,
|
||||
b_gptq_qzeros,
|
||||
b_g_idx,
|
||||
size_m,
|
||||
size_k / 8,
|
||||
size_n
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Dequantizes a 4-bit GPTQ weight matrix using an explicit group index.
// Each thread unpacks one uint32 of `w` (8 values) for one column and writes
// the 8 resulting half values to consecutive rows of `out`.
//
// Parameters:
//   w         packed weights, viewed as a (height x width) 4-bit column view.
//   w_scales  half scales, viewed as (group x width).
//   w_zeros   packed zero points, viewed as (group x width).
//   g_idx     group index per unpacked row.
//   height    rows of the dequantized matrix.
//   width     columns of the dequantized matrix.
//   group     number of groups (row count of the scale / zero views).
//   out       destination, viewed as (height x width).
__global__ void reconstruct_gptq_kernel
(
    const uint32_t* __restrict__ w,
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int* __restrict__ g_idx,
    const int height,
    const int width,
    const int group,
    half* __restrict__ out
)
{
    // Tile origin: one column per thread, 8 rows per block in y
    const int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
    const int row = blockIdx.y * 8;
    if (column >= width) return;

    // Views over the raw buffers
    MatrixView_q4_column w_(w, height, width);
    MatrixView_half_rw out_(out, height, width);
    MatrixView_half w_scales_(w_scales, group, width);
    MatrixView_q4_row w_zeros_(w_zeros, group, width);

    const uint32_t packed = w_.item_uint32_t(row, column);
    half* dst = out_.item_ptr(row, column);

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        const int g = g_idx[row + i];
        const half scale = w_scales_.item(g, column);
        // Stored zero points are offset by one
        const uint32_t zero = w_zeros_.item(g, column) + 1;
        const uint32_t nibble = (packed >> (i * 4)) & 0x0f;

        *dst = __hmul(__int2half_rn((int)nibble - zero), scale);
        dst += out_.width;  // advance to the same column of the next row
    }
}
|
||||
|
||||
|
||||
void reconstruct_gptq
|
||||
(
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_g_idx,
|
||||
half* out,
|
||||
int height,
|
||||
int width,
|
||||
int groups
|
||||
)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.y = DIVIDE(height, 8);
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
b_q_weight,
|
||||
b_gptq_scales,
|
||||
b_gptq_qzeros,
|
||||
b_g_idx,
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
out
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_g_idx,
|
||||
half* c,
|
||||
half* temp_dq,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
int groups,
|
||||
bool use_exllama
|
||||
)
|
||||
{
|
||||
if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
|
||||
// Reconstruct FP16 matrix, then cuBLAS
|
||||
if (use_exllama) {
|
||||
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
|
||||
size_k, size_n, groups);
|
||||
}
|
||||
else
|
||||
{
|
||||
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
|
||||
temp_dq, size_k, size_n, groups);
|
||||
}
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = __float2half(0.0f);
|
||||
cublasHgemm(cublas_handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
size_n, size_m, size_k,
|
||||
&alpha, temp_dq, size_n,
|
||||
a, size_k,
|
||||
&beta, c, size_n);
|
||||
}
|
||||
else if (use_exllama)
|
||||
{
|
||||
// Quantized matmul
|
||||
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
|
||||
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
|
||||
int last_chunk_size = size_m - last_chunk;
|
||||
|
||||
if (max_chunks)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
|
||||
c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
|
||||
groups);
|
||||
}
|
||||
|
||||
if (last_chunk_size)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
|
||||
b_gptq_scales, b_g_idx, c + last_chunk * size_n,
|
||||
last_chunk_size, size_n, size_k, last_chunk_size,
|
||||
groups);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
|
||||
c, size_m, size_n, size_k);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Applies shuffle_4bit_8 in place to every packed word of one column.
// One thread per column; the thread walks down the column one packed row
// (8 values along k) at a time.
//
// Parameters:
//   b_q_weight  packed weights, modified in place.
//   size_k      number of unpacked rows.
//   size_n      number of columns (also the row stride in uint32 words).
__global__ void shuffle_kernel
(
    uint32_t* __restrict__ b_q_weight,
    const int size_k,
    const int size_n
)
{
    const int n = blockIdx.x * THREADS_X + threadIdx.x;
    if (n >= size_n) return;

    uint32_t* b_ptr = b_q_weight + n;
    for (int k = 0; k < size_k; k += 8)
    {
        shuffle_4bit_8(b_ptr, size_n);
        b_ptr += size_n;
    }
}
|
||||
|
||||
|
||||
// Gathers rows of a packed 4-bit matrix according to q_perm and writes the
// re-packed result to w_new. Two adjacent columns are processed at once by
// treating each pair of uint32 words as one uint64.
//
// Parameters:
//   w         source packed weights (8 rows per uint32).
//   w_new     destination packed weights.
//   q_perm    source row for every destination row (unpacked row indices).
//   w_height  not read by this kernel.
//   w_width   number of columns in uint32 words.
__global__ void make_sequential_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const int* __restrict__ q_perm,
    const int w_height,
    const int w_width
)
{
    const uint64_t* src_pairs = (uint64_t*) w;
    uint64_t* dst_pairs = (uint64_t*) w_new;

    const int pair_stride = w_width >> 1;
    const int pair_column = THREADS_X * blockIdx.x + threadIdx.x;
    if (pair_column >= pair_stride) return;

    const int dst_row = blockIdx.y;                 // packed destination row
    const int* perm_rows = q_perm + (dst_row << 3); // its 8 source rows

    uint64_t packed = 0;

    #pragma unroll
    for (int slot = 0; slot < 8; slot++)
    {
        const int source_row = perm_rows[slot];

        // Packed word holding the source row, and the nibble offset inside it
        const uint64_t word = src_pairs[(source_row >> 3) * pair_stride + pair_column];
        const int src_shift = (source_row & 0x07) << 2;

        // Keep one nibble from each 32-bit half, then move it to `slot`
        const uint64_t nibbles = (word >> src_shift) & 0x0000000f0000000f;
        packed |= nibbles << (slot << 2);
    }

    dst_pairs[dst_row * pair_stride + pair_column] = packed;
}
|
||||
|
||||
|
||||
void shuffle_exllama_weight
|
||||
(
|
||||
uint32_t* q_weight,
|
||||
int* q_perm,
|
||||
int height,
|
||||
int width
|
||||
)
|
||||
{
|
||||
if (q_perm)
|
||||
{
|
||||
uint32_t* new_qweight = NULL;
|
||||
cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = height / 8;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
q_weight,
|
||||
new_qweight,
|
||||
q_perm,
|
||||
height / 8,
|
||||
width
|
||||
);
|
||||
// Replace qweights
|
||||
cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
// Cleanup
|
||||
cudaDeviceSynchronize();
|
||||
cudaFree(new_qweight);
|
||||
}
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = 1;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
// Torch entry point: returns a * dequant(b) for GPTQ 4-bit weights.
//
// Allocates the output (a.size(0) x b_q_weight.size(1)) and a scratch buffer
// for dequantized weights (b_q_weight.size(0) * 8 x b_q_weight.size(1)), both
// with the dtype and device of `a`, then forwards raw pointers to
// vllm::gptq::gemm_half_q_half_cuda. A b_g_idx tensor on the meta device is
// passed on as NULL.
torch::Tensor gptq_gemm
(
    torch::Tensor a,
    torch::Tensor b_q_weight,
    torch::Tensor b_gptq_qzeros,
    torch::Tensor b_gptq_scales,
    torch::Tensor b_g_idx,
    bool use_exllama
)
{
    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));

    auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
    const auto out_rows = a.size(0);
    const auto out_cols = b_q_weight.size(1);
    at::Tensor c = torch::empty({out_rows, out_cols}, options);
    at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, out_cols}, options);

    const int* g_idx_ptr =
        b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr();

    vllm::gptq::gemm_half_q_half_cuda
    (
        at::cuda::getCurrentCUDABlasHandle(),
        (const half*) a.data_ptr(),
        (const uint32_t*) b_q_weight.data_ptr(),
        (const uint32_t*) b_gptq_qzeros.data_ptr(),
        (const half*) b_gptq_scales.data_ptr(),
        g_idx_ptr,
        (half*) c.data_ptr(),
        (half*) temp_dq.data_ptr(),
        c.size(0),              // m
        c.size(1),              // n
        a.size(1),              // k
        b_gptq_qzeros.size(0),  // group number
        use_exllama
    );
    return c;
}
|
||||
|
||||
void gptq_shuffle
|
||||
(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
||||
vllm::gptq::shuffle_exllama_weight(
|
||||
(uint32_t*) q_weight.data_ptr(),
|
||||
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
|
||||
q_weight.size(0) * 8,
|
||||
q_weight.size(1)
|
||||
);
|
||||
}
|
||||
235
csrc/quantization/gptq/qdq_4.cuh
Normal file
235
csrc/quantization/gptq/qdq_4.cuh
Normal file
@@ -0,0 +1,235 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
// Reorders the eight 4-bit values of q[0] in place: even-indexed values move
// to the low 16 bits and odd-indexed values to the high 16 bits, each group
// keeping its relative order. `stride` is not used.
__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    const uint32_t packed = q[0];
    uint32_t shuffled = 0;

    #pragma unroll
    for (int pair = 0; pair < 4; pair++)
    {
        const uint32_t byte = (packed >> (pair * 8)) & 0xff;
        const uint32_t even = byte & 0x0f;  // value 2 * pair
        const uint32_t odd = byte >> 4;     // value 2 * pair + 1

        shuffled |= even << (pair * 4);
        shuffled |= odd << (pair * 4 + 16);
    }

    q[0] = shuffled;
}
|
||||
|
||||
// Unpacks eight shuffled 4-bit values from q_0 into four half2 values with a
// fixed zero point of 8 and no scale. `stride` is not used.
//
// Each nibble is OR-ed into the mantissa of the half constant 0x6400 (1024),
// giving q + 1024 for nibbles in bits 0-3 and q * 16 + 1024 for nibbles in
// bits 4-7; the bias (and the factor 16) is then removed arithmetically.
__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t c0 = 0x64006400;

    // Corrections for the low-nibble form: -(1024 + 8)
    const half2 z1 = __half2half2(__float2half_rn(-1024.0f - 8.0f));
    // Corrections for the high-nibble form: * 1/16, then -(64 + 8)
    const half2 y16 = __half2half2(__float2half_rn(1.0f / 16.0f));
    const half2 z16 = __half2half2(__float2half_rn(-1024.0f / 16.0f - 8.0f));

    const uint32_t lo_word = q_0;
    const uint32_t hi_word = q_0 >> 8;

    half2_uint32 q0((lo_word & 0x000f000f) | c0);  // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((lo_word & 0x00f000f0) | c0);  // half2(q[ 2], q[ 3]) * 16 + 1024
    half2_uint32 q2((hi_word & 0x000f000f) | c0);  // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((hi_word & 0x00f000f0) | c0);  // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}
|
||||
|
||||
// Precomputes the multiply/add constants used by dequant_4bit_8_gptq with
// `scaled == true`, folding both the zero point and the scale into them.
//
// Outputs:
//   z1z16[0] = scale * (-1024 - zero)   additive term, low-nibble form
//   z1z16[1] = scale * (-64 - zero)     additive term, high-nibble form
//   y1y16[0] = scale * 1                multiplier, low-nibble form
//   y1y16[1] = scale * 1/16             multiplier, high-nibble form
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    const half2 scale2 = __half2half2(scale);

    // 0xe400 is the half bit pattern of -1024; OR-ing the zero point into the
    // mantissa yields half(-1024.0f - zero) without a conversion.
    half_uint16 z1(0xe400 | zero);
    const half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));
    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
|
||||
|
||||
// Precomputes the constants used by dequant_4bit_8_gptq with
// `scaled == false`: only the zero point is folded in, no scale.
//
// Outputs:
//   z1z16[0] = -1024 - zero   additive term, low-nibble form
//   z1z16[1] = -64 - zero     additive term, high-nibble form
//   y1y16[0] = 1              multiplier, low-nibble form
//   y1y16[1] = 1/16           multiplier, high-nibble form
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1z16)[2],
    half2(&y1y16)[2]
)
{
    // 0xe400 is the half bit pattern of -1024; OR-ing the zero point into the
    // mantissa yields half(-1024.0f - zero) without a conversion.
    half_uint16 z1(0xe400 | zero);
    const half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);
    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}
|
||||
|
||||
|
||||
// Unpacks eight shuffled 4-bit values from q_0 into four half2 values using
// constants prepared by dequant_4bit_8_prep_zero[_scale].
//
// Parameters:
//   q_0     packed word.
//   dq      output: half2(q[0],q[1]), half2(q[2],q[3]), half2(q[4],q[5]),
//           half2(q[6],q[7]) after zero-point (and optional scale) correction.
//   z1z16   additive terms for the low-nibble / high-nibble forms.
//   y1y16   multipliers for the low-nibble / high-nibble forms.
//   stride  not used.
//   scaled  true: multiply-add for every output (constants include the scale);
//           false: plain add for the low-nibble form, multiply-add only where
//           the factor 16 has to be removed.
__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;  // half2(1024, 1024) bit pattern

    const uint32_t lo_word = q_0;
    const uint32_t hi_word = q_0 >> 8;

    half2_uint32 q0((lo_word & 0x000f000f) | c0);  // half2( q[0]      + 1024, q[1]      + 1024 )
    half2_uint32 q1((lo_word & 0x00f000f0) | c0);  // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    half2_uint32 q2((hi_word & 0x000f000f) | c0);  // half2( q[4]      + 1024, q[5]      + 1024 )
    half2_uint32 q3((hi_word & 0x00f000f0) | c0);  // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    // High-nibble outputs use the same multiply-add in both modes
    dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);
    dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);

    if (!scaled)
    {
        dq[0] = __hadd2(q0.as_half2, z1z16[0]);  // half2( q[0] - z, q[1] - z )
        dq[2] = __hadd2(q2.as_half2, z1z16[0]);  // half2( q[4] - z, q[5] - z )
        return;
    }

    dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s )
    dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
}
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
#else
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
// Fallback variant: leaves the packed word untouched (no shuffling).
//
// NOTE(review): this definition sits in the #else branch of the
// `#ifndef _qdq_4_cuh` guard, so it is compiled only when the guard macro is
// already defined -- confirm that this is intended.
__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    // Intentionally empty; both parameters are ignored.
    (void) q;
    (void) stride;
}
|
||||
|
||||
// Fallback variant: unpacks eight sequential (unshuffled) 4-bit values from
// q_0, subtracts a fixed zero point of 8 and pairs consecutive values into
// half2. `stride` is not used.
//
// NOTE(review): this definition sits in the #else branch of the
// `#ifndef _qdq_4_cuh` guard, so it is compiled only when the guard macro is
// already defined -- confirm that this is intended.
__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    for (int pair = 0; pair < 4; pair++)
    {
        const half first = dq_ns(exb(q_0, pair * 8, 0x0f), 8);
        const half second = dq_ns(exb(q_0, pair * 8 + 4, 0x0f), 8);
        dq[pair] = __halves2half2(first, second);
    }
}
|
||||
|
||||
// Fallback variant: stores -zero * scale in z1[0] and scale in y1[0], both
// broadcast to the two halves. z1[1] and y1[1] are left untouched.
//
// NOTE(review): this definition sits in the #else branch of the
// `#ifndef _qdq_4_cuh` guard, so it is compiled only when the guard macro is
// already defined -- confirm that this is intended.
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    const half neg_zero = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(__hmul(neg_zero, scale));
    y1[0] = __half2half2(scale);
}
|
||||
|
||||
// Fallback variant: stores -zero, broadcast to both halves, in z1[0].
// z1[1] and all of y1 are left untouched.
//
// NOTE(review): this definition sits in the #else branch of the
// `#ifndef _qdq_4_cuh` guard, so it is compiled only when the guard macro is
// already defined -- confirm that this is intended.
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1)[2],
    half2(&y1)[2]
)
{
    const half neg_zero = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(neg_zero);
}
|
||||
|
||||
// Fallback variant: unpacks eight sequential (unshuffled) 4-bit values from
// q_0 into four half2 values.
//
//   scaled == true : dq[i] = raw * y1[0] + z1[0]
//   scaled == false: dq[i] = raw + z1[0]
//
// Only element 0 of z1 / y1 is read; `stride` is not used.
//
// NOTE(review): this definition sits in the #else branch of the
// `#ifndef _qdq_4_cuh` guard, so it is compiled only when the guard macro is
// already defined -- confirm that this is intended.
__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    for (int i = 0; i < 4; i++)
    {
        // Byte i holds values 2i (low nibble) and 2i + 1 (high nibble)
        const half lo = __int2half_rn((q_0 >> (i * 8)) & 0x0f);
        const half hi = __int2half_rn((q_0 >> (i * 8 + 4)) & 0x0f);
        const half2 raw = __halves2half2(lo, hi);

        if (scaled)
            dq[i] = __hfma2(raw, y1[0], z1[0]);
        else
            dq[i] = __hadd2(raw, z1[0]);
    }
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
#endif
|
||||
60
csrc/quantization/gptq/qdq_util.cuh
Normal file
60
csrc/quantization/gptq/qdq_util.cuh
Normal file
@@ -0,0 +1,60 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_util_cuh
|
||||
#define _qdq_util_cuh
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
union half2_uint32
|
||||
{
|
||||
uint32_t as_uint32;
|
||||
half2 as_half2;
|
||||
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||
};
|
||||
|
||||
union half_uint16
|
||||
{
|
||||
uint16_t as_uint16;
|
||||
half as_half;
|
||||
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||
__device__ half_uint16(half val) : as_half(val) {}
|
||||
};
|
||||
|
||||
// Max_scale premultiplied by 1/256
|
||||
|
||||
// Returns (qs + 1)^2 * max_scale as a half.
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
    const int level = qs + 1;
    const half squared = __int2half_rn(level * level);
    return __hmul(squared, max_scale);
}
|
||||
|
||||
// Dequantizes one value: (q - qzero) * scale, with the subtraction done in
// integer arithmetic before conversion to half.
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
    const half centered = __int2half_rn(q - qzero);
    return __hmul(centered, scale);
}
|
||||
|
||||
// Dequantizes one value without a scale: returns q - qzero as a half, with
// the subtraction done in integer arithmetic before conversion.
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
    const int centered = q - qzero;
    return __int2half_rn(centered);
}
|
||||
|
||||
// Extracts a bit field from one word: shifts q right by `shift` and keeps the
// bits selected by `mask`.
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
    const uint32_t shifted = q >> shift;
    return (int)(shifted & mask);
}
|
||||
|
||||
// Extracts a bit field that may straddle two words: funnel-shifts the pair
// (q0 passed as the first operand, q1 as the second) right by `shift` with
// __funnelshift_rc, then keeps the bits selected by `mask`.
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
    const uint32_t shifted = __funnelshift_rc(q0, q1, shift);
    return (int)(shifted & mask);
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
||||
@@ -7,6 +7,7 @@
|
||||
// half-tensor
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <ATen/cuda/CUDATensorMethods.cuh>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#define BLOCKWIDTH 128
|
||||
#define BLOCKHEIGHT4 16
|
||||
@@ -20,9 +21,17 @@ __device__ inline unsigned int as_unsigned(int i) {
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
__global__ void NUQ4MatMulKernel(
|
||||
#ifndef USE_ROCM
|
||||
const half2* __restrict__ vec,
|
||||
#else
|
||||
const __half2* __restrict__ vec,
|
||||
#endif
|
||||
const int* __restrict__ mat,
|
||||
#ifndef USE_ROCM
|
||||
half2* __restrict__ mul,
|
||||
#else
|
||||
float2* __restrict__ mul,
|
||||
#endif
|
||||
const __half* __restrict__ lookup_table,
|
||||
int height,
|
||||
int width,
|
||||
@@ -35,7 +44,11 @@ __global__ void NUQ4MatMulKernel(
|
||||
int row = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
__shared__ half2 blockvec[blockwidth2];
|
||||
#else
|
||||
__shared__ __half2 blockvec[blockwidth2];
|
||||
#endif
|
||||
|
||||
__shared__ __half deq2[16][BLOCKWIDTH];
|
||||
int off = threadIdx.x;
|
||||
@@ -46,8 +59,13 @@ __global__ void NUQ4MatMulKernel(
|
||||
}
|
||||
|
||||
__half res;
|
||||
#ifndef USE_ROCM
|
||||
half2 res2;
|
||||
half2 tmp2;
|
||||
#else
|
||||
__half2 res2;
|
||||
__half2 tmp2;
|
||||
#endif
|
||||
|
||||
int i;
|
||||
int k;
|
||||
@@ -68,48 +86,96 @@ __global__ void NUQ4MatMulKernel(
|
||||
while (k < blockwidth2) {
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res2 = {};
|
||||
tmp2 = {};
|
||||
#else
|
||||
res2.x = __half_as_ushort(__float2half(0));
|
||||
res2.y = __half_as_ushort(__float2half(0));
|
||||
tmp2.x = __half_as_ushort(__float2half(0));
|
||||
tmp2.y = __half_as_ushort(__float2half(0));
|
||||
#endif
|
||||
|
||||
lut_index1 = tmp1 & 0xF;
|
||||
lut_index2 = (tmp1 >> 4) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 8) & 0xF;
|
||||
lut_index2 = (tmp1 >> 12) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 16) & 0xF;
|
||||
lut_index2 = (tmp1 >> 20) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 24) & 0xF;
|
||||
lut_index2 = (tmp1 >> 28) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||
#else
|
||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
|
||||
#endif
|
||||
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
|
||||
// col%2 -> only set one of the two values
|
||||
#ifndef USE_ROCM
|
||||
half2 res3 = {};
|
||||
if (col % 2 == 0) {
|
||||
res3.x = res;
|
||||
} else {
|
||||
res3.y = res;
|
||||
}
|
||||
#else
|
||||
__half2 res3;
|
||||
res3.x = __half_as_ushort(__float2half(0));
|
||||
res3.y = __half_as_ushort(__float2half(0));
|
||||
if (col % 2 == 0) {
|
||||
res3.x = __half_as_ushort(res);
|
||||
} else {
|
||||
res3.y = __half_as_ushort(res);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
atomicAdd(&mul[b * width / 2 + col / 2], res3);
|
||||
#else
|
||||
int tmp_addr = b * width / 2 + col / 2;
|
||||
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
|
||||
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,11 +201,22 @@ void squeezellm_gemm(
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
|
||||
#ifndef USE_ROCM
|
||||
(half2*) vec.data<at::Half>(),
|
||||
#else
|
||||
(__half2*) vec.data_ptr<at::Half>(),
|
||||
#endif
|
||||
mat.data_ptr<int>(),
|
||||
#ifndef USE_ROCM
|
||||
(half2*) mul.data<at::Half>(),
|
||||
(__half*) lookup_table.data<at::Half>(),
|
||||
#else
|
||||
(float2*) mul.data_ptr<float>(),
|
||||
(__half*) lookup_table.data_ptr<at::Half>(),
|
||||
#endif
|
||||
height, width, batch, vec_height
|
||||
);
|
||||
}
|
||||
|
||||
@@ -17,13 +17,15 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cuda_compat.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
|
||||
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
||||
return val;
|
||||
}
|
||||
|
||||
|
||||
144
docs/source/getting_started/amd-installation.rst
Normal file
144
docs/source/getting_started/amd-installation.rst
Normal file
@@ -0,0 +1,144 @@
|
||||
.. _installation_rocm:
|
||||
|
||||
Installation with ROCm
|
||||
======================
|
||||
|
||||
vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
|
||||
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
|
||||
Data types currently supported in ROCm are FP16 and BF16.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
* OS: Linux
|
||||
* Python: 3.8 -- 3.11 (Verified on 3.10)
|
||||
* GPU: MI200s
|
||||
* Pytorch 2.0.1/2.1.1/2.2
|
||||
* ROCm 5.7
|
||||
|
||||
Installation options:
|
||||
|
||||
#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image <quick_start_docker_rocm>`
|
||||
#. :ref:`Build from source <build_from_source_rocm>`
|
||||
#. :ref:`Build from source with docker <build_from_source_docker_rocm>`
|
||||
|
||||
.. _quick_start_docker_rocm:
|
||||
|
||||
(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
|
||||
$ docker run -it \
|
||||
--network=host \
|
||||
--group-add=video \
|
||||
--ipc=host \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v <path/to/model>:/app/model \
|
||||
embeddedllminfo/vllm-rocm \
|
||||
bash
|
||||
|
||||
|
||||
.. _build_from_source_rocm:
|
||||
|
||||
Option 2: Build from source
|
||||
---------------------------
|
||||
|
||||
You can build and install vLLM from source:
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
|
||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||
- `Pytorch <https://pytorch.org/>`_
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version
|
||||
|
||||
|
||||
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
|
||||
|
||||
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
||||
|
||||
.. note::
|
||||
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
||||
- If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
||||
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding window attention.
|
||||
- You might need to downgrade the "ninja" version to 1.10 if it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||
|
||||
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install xformers==0.0.23 --no-deps
|
||||
$ bash patch_xformers.rocm.sh
|
||||
|
||||
3. Build vLLM.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ cd vllm
|
||||
$ pip install -U -r requirements-rocm.txt
|
||||
$ python setup.py install # This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation
|
||||
|
||||
|
||||
.. _build_from_source_docker_rocm:
|
||||
|
||||
Option 3: Build from source with docker
|
||||
-----------------------------------------------------
|
||||
|
||||
You can build and install vLLM from source:
|
||||
|
||||
Build a docker image from `Dockerfile.rocm`, and launch a docker container.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||
$ docker run -it \
|
||||
--network=host \
|
||||
--group-add=video \
|
||||
--ipc=host \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v <path/to/model>:/app/model \
|
||||
vllm-rocm \
|
||||
bash
|
||||
|
||||
Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below:
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
|
||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||
- `Pytorch <https://pytorch.org/>`_
|
||||
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
|
||||
|
||||
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
|
||||
|
||||
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
||||
|
||||
.. note::
|
||||
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
||||
- If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
||||
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding window attention.
|
||||
- You might need to downgrade the "ninja" version to 1.10 if it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||
|
||||
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install xformers==0.0.23 --no-deps
|
||||
$ bash patch_xformers.rocm.sh
|
||||
|
||||
3. Build vLLM.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ cd vllm
|
||||
$ pip install -U -r requirements-rocm.txt
|
||||
$ python setup.py install # This may take 5-10 minutes.
|
||||
@@ -3,14 +3,14 @@
|
||||
Installation
|
||||
============
|
||||
|
||||
vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
|
||||
vLLM is a Python library that also contains pre-compiled C++ and CUDA (12.1) binaries.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
* OS: Linux
|
||||
* Python: 3.8 -- 3.11
|
||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.)
|
||||
|
||||
Install with pip
|
||||
----------------
|
||||
@@ -20,12 +20,32 @@ You can install vLLM using pip:
|
||||
.. code-block:: console
|
||||
|
||||
$ # (Optional) Create a new conda environment.
|
||||
$ conda create -n myenv python=3.8 -y
|
||||
$ conda create -n myenv python=3.9 -y
|
||||
$ conda activate myenv
|
||||
|
||||
$ # Install vLLM.
|
||||
$ # Install vLLM with CUDA 12.1.
|
||||
$ pip install vllm
|
||||
|
||||
.. note::
|
||||
|
||||
As of now, vLLM's binaries are compiled on CUDA 12.1 by default.
|
||||
However, you can install vLLM with CUDA 11.8 by running:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Install vLLM with CUDA 11.8.
|
||||
$ export VLLM_VERSION=0.2.4
|
||||
$ export PYTHON_VERSION=39
|
||||
$ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
|
||||
|
||||
$ # Re-install PyTorch with CUDA 11.8.
|
||||
$ pip uninstall torch -y
|
||||
$ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
$ # Re-install xFormers with CUDA 11.8.
|
||||
$ pip uninstall xformers -y
|
||||
$ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
|
||||
.. _build_from_source:
|
||||
|
||||
@@ -45,6 +65,5 @@ You can also build and install vLLM from source:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Pull the Docker image with CUDA 11.8.
|
||||
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3
|
||||
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
|
||||
|
||||
@@ -107,6 +107,7 @@ OpenAI-Compatible Server
|
||||
------------------------
|
||||
|
||||
vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
|
||||
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
|
||||
|
||||
Start the server:
|
||||
|
||||
@@ -122,7 +123,13 @@ Use model from www.modelscope.cn
|
||||
$ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
|
||||
$ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code
|
||||
|
||||
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
|
||||
By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m vllm.entrypoints.openai.api_server \
|
||||
$ --model facebook/opt-125m \
|
||||
$ --chat-template ./examples/template_chatml.jinja
|
||||
|
||||
This server can be queried in the same format as OpenAI API. For example, list the models:
|
||||
|
||||
@@ -130,6 +137,9 @@ This server can be queried in the same format as OpenAI API. For example, list t
|
||||
|
||||
$ curl http://localhost:8000/v1/models
|
||||
|
||||
Using OpenAI Completions API with vLLM
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Query the model with input prompts:
|
||||
|
||||
.. code-block:: console
|
||||
@@ -147,12 +157,65 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai.api_key = "EMPTY"
|
||||
openai.api_base = "http://localhost:8000/v1"
|
||||
completion = openai.Completion.create(model="facebook/opt-125m",
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
completion = client.completions.create(model="facebook/opt-125m",
|
||||
prompt="San Francisco is a")
|
||||
print("Completion result:", completion)
|
||||
|
||||
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
||||
|
||||
Using OpenAI Chat API with vLLM
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
The vLLM server is designed to support the OpenAI Chat API, allowing you to engage in dynamic conversations with the model. The chat interface is a more interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
|
||||
|
||||
Querying the model using OpenAI Chat API:
|
||||
|
||||
You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to communicate with the model in a chat-like interface:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ curl http://localhost:8000/v1/chat/completions \
|
||||
$ -H "Content-Type: application/json" \
|
||||
$ -d '{
|
||||
$ "model": "facebook/opt-125m",
|
||||
$ "messages": [
|
||||
$ {"role": "system", "content": "You are a helpful assistant."},
|
||||
$ {"role": "user", "content": "Who won the world series in 2020?"}
|
||||
$ ]
|
||||
$ }'
|
||||
|
||||
Python Client Example:
|
||||
|
||||
Using the `openai` python package, you can also communicate with the model in a chat-like manner:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from openai import OpenAI
|
||||
# Set OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
model="facebook/opt-125m",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Tell me a joke."},
|
||||
]
|
||||
)
|
||||
print("Chat response:", chat_response)
|
||||
|
||||
For more in-depth examples and advanced features of the chat API, you can refer to the official OpenAI documentation.
|
||||
|
||||
@@ -30,6 +30,8 @@ vLLM is fast with:
|
||||
* State-of-the-art serving throughput
|
||||
* Efficient management of attention key and value memory with **PagedAttention**
|
||||
* Continuous batching of incoming requests
|
||||
* Fast model execution with CUDA/HIP graph
|
||||
* Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, `SqueezeLLM <https://arxiv.org/abs/2306.07629>`_
|
||||
* Optimized CUDA kernels
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
@@ -39,6 +41,7 @@ vLLM is flexible and easy to use with:
|
||||
* Tensor parallelism support for distributed inference
|
||||
* Streaming outputs
|
||||
* OpenAI-compatible API server
|
||||
* Support NVIDIA GPUs and AMD GPUs
|
||||
|
||||
For more information, check out the following:
|
||||
|
||||
@@ -56,6 +59,7 @@ Documentation
|
||||
:caption: Getting Started
|
||||
|
||||
getting_started/installation
|
||||
getting_started/amd-installation
|
||||
getting_started/quickstart
|
||||
|
||||
.. toctree::
|
||||
@@ -66,6 +70,8 @@ Documentation
|
||||
serving/run_on_sky
|
||||
serving/deploying_with_triton
|
||||
serving/deploying_with_docker
|
||||
serving/serving_with_langchain
|
||||
serving/metrics
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
@@ -73,6 +79,7 @@ Documentation
|
||||
|
||||
models/supported_models
|
||||
models/adding_model
|
||||
models/engine_args
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
@@ -18,7 +18,7 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor
|
||||
0. Fork the vLLM repository
|
||||
--------------------------------
|
||||
|
||||
Start by forking our `GitHub <https://github.com/vllm-project/vllm/>`_ repository and then :ref:`build it from source <build_from_source>`.
|
||||
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
|
||||
This gives you the ability to modify the codebase and test your model.
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ This gives you the ability to modify the codebase and test your model.
|
||||
------------------------
|
||||
|
||||
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`_ directory.
|
||||
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adpated from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
|
||||
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adapted from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
|
||||
|
||||
.. warning::
|
||||
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
|
||||
@@ -58,11 +58,10 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
|
||||
+ positions: torch.Tensor,
|
||||
+ kv_caches: List[KVCache],
|
||||
+ input_metadata: InputMetadata,
|
||||
+ cache_events: Optional[List[torch.cuda.Event]],
|
||||
+) -> SamplerOutput:
|
||||
+) -> Optional[SamplerOutput]:
|
||||
|
||||
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
||||
4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
|
||||
1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
||||
2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
|
||||
|
||||
.. note::
|
||||
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
|
||||
|
||||
116
docs/source/models/engine_args.rst
Normal file
116
docs/source/models/engine_args.rst
Normal file
@@ -0,0 +1,116 @@
|
||||
.. _engine_args:
|
||||
|
||||
Engine Arguments
|
||||
================
|
||||
|
||||
Below, you can find an explanation of every engine argument for vLLM:
|
||||
|
||||
.. option:: --model <model_name_or_path>
|
||||
|
||||
Name or path of the huggingface model to use.
|
||||
|
||||
.. option:: --tokenizer <tokenizer_name_or_path>
|
||||
|
||||
Name or path of the huggingface tokenizer to use.
|
||||
|
||||
.. option:: --revision <revision>
|
||||
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
|
||||
|
||||
.. option:: --tokenizer-revision <revision>
|
||||
|
||||
The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
|
||||
|
||||
.. option:: --tokenizer-mode {auto,slow}
|
||||
|
||||
The tokenizer mode.
|
||||
|
||||
* "auto" will use the fast tokenizer if available.
|
||||
* "slow" will always use the slow tokenizer.
|
||||
|
||||
.. option:: --trust-remote-code
|
||||
|
||||
Trust remote code from huggingface.
|
||||
|
||||
.. option:: --download-dir <directory>
|
||||
|
||||
Directory to download and load the weights; defaults to the default cache dir of huggingface.
|
||||
|
||||
.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
|
||||
|
||||
The format of the model weights to load.
|
||||
|
||||
* "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
|
||||
* "pt" will load the weights in the pytorch bin format.
|
||||
* "safetensors" will load the weights in the safetensors format.
|
||||
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
|
||||
* "dummy" will initialize the weights with random values, mainly for profiling.
|
||||
|
||||
.. option:: --dtype {auto,half,float16,bfloat16,float,float32}
|
||||
|
||||
Data type for model weights and activations.
|
||||
|
||||
* "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
|
||||
* "half" for FP16. Recommended for AWQ quantization.
|
||||
* "float16" is the same as "half".
|
||||
* "bfloat16" for a balance between precision and range.
|
||||
* "float" is shorthand for FP32 precision.
|
||||
* "float32" for FP32 precision.
|
||||
|
||||
.. option:: --max-model-len <length>
|
||||
|
||||
Model context length. If unspecified, will be automatically derived from the model config.
|
||||
|
||||
.. option:: --worker-use-ray
|
||||
|
||||
Use Ray for distributed serving, will be automatically set when using more than 1 GPU.
|
||||
|
||||
.. option:: --pipeline-parallel-size (-pp) <size>
|
||||
|
||||
Number of pipeline stages.
|
||||
|
||||
.. option:: --tensor-parallel-size (-tp) <size>
|
||||
|
||||
Number of tensor parallel replicas.
|
||||
|
||||
.. option:: --max-parallel-loading-workers <workers>
|
||||
|
||||
Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models.
|
||||
|
||||
.. option:: --block-size {8,16,32}
|
||||
|
||||
Token block size for contiguous chunks of tokens.
|
||||
|
||||
.. option:: --seed <seed>
|
||||
|
||||
Random seed for operations.
|
||||
|
||||
.. option:: --swap-space <size>
|
||||
|
||||
CPU swap space size (GiB) per GPU.
|
||||
|
||||
.. option:: --gpu-memory-utilization <fraction>
|
||||
|
||||
The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.
|
||||
For example, a value of 0.5 would imply 50% GPU memory utilization.
|
||||
If unspecified, will use the default value of 0.9.
|
||||
|
||||
.. option:: --max-num-batched-tokens <tokens>
|
||||
|
||||
Maximum number of batched tokens per iteration.
|
||||
|
||||
.. option:: --max-num-seqs <sequences>
|
||||
|
||||
Maximum number of sequences per iteration.
|
||||
|
||||
.. option:: --max-paddings <paddings>
|
||||
|
||||
Maximum number of paddings in a batch.
|
||||
|
||||
.. option:: --disable-log-stats
|
||||
|
||||
Disable logging statistics.
|
||||
|
||||
.. option:: --quantization (-q) {awq,squeezellm,None}
|
||||
|
||||
Method used to quantize the weights.
|
||||
@@ -19,10 +19,13 @@ Alongside each architecture, we include some popular models that use it.
|
||||
- :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
|
||||
* - :code:`BaiChuanForCausalLM`
|
||||
- Baichuan
|
||||
- :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
|
||||
- :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc.
|
||||
* - :code:`ChatGLMModel`
|
||||
- ChatGLM
|
||||
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
|
||||
* - :code:`DeciLMForCausalLM`
|
||||
- DeciLM
|
||||
- :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
|
||||
* - :code:`BloomForCausalLM`
|
||||
- BLOOM, BLOOMZ, BLOOMChat
|
||||
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
|
||||
@@ -50,6 +53,9 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`MistralForCausalLM`
|
||||
- Mistral, Mistral-Instruct
|
||||
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
|
||||
* - :code:`MixtralForCausalLM`
|
||||
- Mixtral-8x7B, Mixtral-8x7B-Instruct
|
||||
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
|
||||
* - :code:`MPTForCausalLM`
|
||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||
@@ -57,8 +63,8 @@ Alongside each architecture, we include some popular models that use it.
|
||||
- OPT, OPT-IML
|
||||
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
||||
* - :code:`PhiForCausalLM`
|
||||
- Phi-1.5
|
||||
- :code:`microsoft/phi-1_5`, etc.
|
||||
- Phi
|
||||
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
|
||||
* - :code:`QWenLMHeadModel`
|
||||
- Qwen
|
||||
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||
@@ -70,6 +76,9 @@ If your model uses one of the above model architectures, you can seamlessly run
|
||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
|
||||
Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
|
||||
|
||||
.. note::
|
||||
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
|
||||
|
||||
.. tip::
|
||||
The easiest way to check if your model is supported is to run the program below:
|
||||
|
||||
@@ -81,12 +90,17 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
|
||||
output = llm.generate("Hello, my name is")
|
||||
print(output)
|
||||
|
||||
To use model from www.modelscope.cn
|
||||
If vLLM successfully generates text, it indicates that your model is supported.
|
||||
|
||||
.. tip::
|
||||
To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
$ export VLLM_USE_MODELSCOPE=True
|
||||
|
||||
And use with :code:`trust_remote_code=True`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from vllm import LLM
|
||||
@@ -94,5 +108,3 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
|
||||
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
|
||||
output = llm.generate("Hello, my name is")
|
||||
print(output)
|
||||
|
||||
If vLLM successfully generates text, it indicates that your model is supported.
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
AutoAWQ
|
||||
==================
|
||||
|
||||
.. warning::
|
||||
|
||||
Please note that AWQ support in vLLM is under-optimized at the moment. We would recommend using the unquantized version of the model for better
|
||||
accuracy and higher throughput. Currently, you can use AWQ as a way to reduce memory footprint. As of now, it is more suitable for low latency
|
||||
inference with a small number of concurrent requests. vLLM's AWQ implementation has lower throughput than the unquantized version.
|
||||
|
||||
To create a new 4-bit quantized model, you can leverage `AutoAWQ <https://github.com/casper-hansen/AutoAWQ>`_.
|
||||
Quantizing reduces the model's precision from FP16 to INT4 which effectively reduces the file size by ~70%.
|
||||
The main benefits are lower latency and memory usage.
|
||||
|
||||
@@ -3,11 +3,41 @@
|
||||
Deploying with Docker
|
||||
============================
|
||||
|
||||
vLLM offers an official Docker image for deployment.
|
||||
The image can be used to run an OpenAI-compatible server.
|
||||
The image is available on Docker Hub as `vllm/vllm-openai <https://hub.docker.com/r/vllm/vllm-openai/tags>`_.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker run --runtime nvidia --gpus all \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=<secret>" \
|
||||
-p 8000:8000 \
|
||||
--ipc=host \
|
||||
vllm/vllm-openai:latest \
|
||||
--model mistralai/Mistral-7B-v0.1
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
You can either use the ``--ipc=host`` flag or ``--shm-size`` flag to allow the
|
||||
container to access the host's shared memory. vLLM uses PyTorch, which uses shared
|
||||
memory to share data between processes under the hood, particularly for tensor parallel inference.
|
||||
|
||||
|
||||
You can build and run vLLM from source via the provided dockerfile. To build vLLM:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ DOCKER_BUILDKIT=1 docker build . --target vllm --tag vllm --build-arg max_jobs=8
|
||||
$ DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm/vllm-openai # optionally specifies: --build-arg max_jobs=8 --build-arg nvcc_threads=2
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
By default vLLM will build for all GPU types for widest distribution. If you are just building for the
|
||||
current GPU type the machine is running on, you can add the argument ``--build-arg torch_cuda_arch_list=""``
|
||||
for vLLM to find the current GPU type and build for that.
|
||||
|
||||
|
||||
To run vLLM:
|
||||
|
||||
@@ -17,5 +47,5 @@ To run vLLM:
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
-p 8000:8000 \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=<secret>" \
|
||||
vllm <args...>
|
||||
vllm/vllm-openai <args...>
|
||||
|
||||
|
||||
13
docs/source/serving/metrics.rst
Normal file
13
docs/source/serving/metrics.rst
Normal file
@@ -0,0 +1,13 @@
|
||||
Production Metrics
|
||||
==================
|
||||
|
||||
vLLM exposes a number of metrics that can be used to monitor the health of the
|
||||
system. These metrics are exposed via the `/metrics` endpoint on the vLLM
|
||||
OpenAI compatible API server.
|
||||
|
||||
The following metrics are exposed:
|
||||
|
||||
.. literalinclude:: ../../../vllm/engine/metrics.py
|
||||
:language: python
|
||||
:start-after: begin-metrics-definitions
|
||||
:end-before: end-metrics-definitions
|
||||
@@ -55,7 +55,7 @@ Start the serving the LLaMA-13B model on an A100 GPU:
|
||||
|
||||
$ sky launch serving.yaml
|
||||
|
||||
Check the output of the command. There will be a sharable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
|
||||
Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
|
||||
31
docs/source/serving/serving_with_langchain.rst
Normal file
31
docs/source/serving/serving_with_langchain.rst
Normal file
@@ -0,0 +1,31 @@
|
||||
.. _run_on_langchain:
|
||||
|
||||
Serving with Langchain
|
||||
============================
|
||||
|
||||
vLLM is also available via `Langchain <https://github.com/langchain-ai/langchain>`_ .
|
||||
|
||||
To install langchain, run
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install langchain -q
|
||||
|
||||
To run inference on a single or multiple GPUs, use the ``VLLM`` class from ``langchain``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import VLLM
|
||||
|
||||
llm = VLLM(model="mosaicml/mpt-7b",
|
||||
trust_remote_code=True, # mandatory for hf models
|
||||
max_new_tokens=128,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
temperature=0.8,
|
||||
# tensor_parallel_size=... # for distributed inference
|
||||
)
|
||||
|
||||
print(llm("What is the capital of France ?"))
|
||||
|
||||
Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/llms/vllm.ipynb>`_ for more details.
|
||||
@@ -47,6 +47,6 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
demo = build_demo()
|
||||
demo.queue(concurrency_count=100).launch(server_name=args.host,
|
||||
server_port=args.port,
|
||||
share=True)
|
||||
demo.queue().launch(server_name=args.host,
|
||||
server_port=args.port,
|
||||
share=True)
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai.api_key = "EMPTY"
|
||||
openai.api_base = "http://localhost:8000/v1"
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
# List models API
|
||||
models = openai.Model.list()
|
||||
print("Models:", models)
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
model = models["data"][0]["id"]
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
# Chat completion API
|
||||
chat_completion = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
@@ -27,7 +28,10 @@ chat_completion = openai.ChatCompletion.create(
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}])
|
||||
}],
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
print("Chat completion results:")
|
||||
print(chat_completion)
|
||||
|
||||
@@ -1,24 +1,28 @@
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai.api_key = "EMPTY"
|
||||
openai.api_base = "http://localhost:8000/v1"
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
# List models API
|
||||
models = openai.Model.list()
|
||||
print("Models:", models)
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
model = models["data"][0]["id"]
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
# Completion API
|
||||
stream = False
|
||||
completion = openai.Completion.create(
|
||||
completion = client.completions.create(
|
||||
model=model,
|
||||
prompt="A robot may not injure a human being",
|
||||
echo=False,
|
||||
n=2,
|
||||
stream=stream,
|
||||
logprobs=3)
|
||||
logprobs=3
|
||||
)
|
||||
|
||||
print("Completion results:")
|
||||
if stream:
|
||||
|
||||
29
examples/template_alpaca.jinja
Normal file
29
examples/template_alpaca.jinja
Normal file
@@ -0,0 +1,29 @@
|
||||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
### Instruction:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
### Response:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'user_context' %}
|
||||
### Input:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
|
||||
### Response:
|
||||
{% endif %}
|
||||
2
examples/template_chatml.jinja
Normal file
2
examples/template_chatml.jinja
Normal file
@@ -0,0 +1,2 @@
|
||||
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
|
||||
30
examples/template_inkbot.jinja
Normal file
30
examples/template_inkbot.jinja
Normal file
@@ -0,0 +1,30 @@
|
||||
<#meta#>
|
||||
- Date: {{ (messages|selectattr('role', 'equalto', 'meta-current_date')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-current_date')|list) else '' }}
|
||||
- Task: {{ (messages|selectattr('role', 'equalto', 'meta-task_name')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-task_name')|list) else '' }}
|
||||
<#system#>
|
||||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
<#chat#>
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
<#user#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
<#bot#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'user_context' %}
|
||||
<#user_context#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
|
||||
<#bot#>
|
||||
{% endif %}
|
||||
16
format.sh
16
format.sh
@@ -7,7 +7,7 @@
|
||||
# # Format files that differ from origin/main.
|
||||
# bash format.sh
|
||||
|
||||
# # Commit changed files with message 'Run yapf and pylint'
|
||||
# # Commit changed files with message 'Run yapf and ruff'
|
||||
#
|
||||
#
|
||||
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
|
||||
@@ -22,7 +22,7 @@ ROOT="$(git rev-parse --show-toplevel)"
|
||||
builtin cd "$ROOT" || exit 1
|
||||
|
||||
YAPF_VERSION=$(yapf --version | awk '{print $2}')
|
||||
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
|
||||
RUFF_VERSION=$(ruff --version | awk '{print $2}')
|
||||
MYPY_VERSION=$(mypy --version | awk '{print $2}')
|
||||
|
||||
# # params: tool name, tool version, required version
|
||||
@@ -34,7 +34,7 @@ tool_version_check() {
|
||||
}
|
||||
|
||||
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
|
||||
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
|
||||
tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)"
|
||||
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
|
||||
|
||||
YAPF_FLAGS=(
|
||||
@@ -95,14 +95,14 @@ echo 'vLLM yapf: Done'
|
||||
|
||||
# Lint specified files
|
||||
lint() {
|
||||
pylint "$@"
|
||||
ruff "$@"
|
||||
}
|
||||
|
||||
# Lint files that differ from main branch. Ignores dirs that are not slated
|
||||
# for autolint yet.
|
||||
lint_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause pylint to receive 0 positional arguments, making it hang
|
||||
# could cause ruff to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
||||
@@ -111,13 +111,13 @@ lint_changed() {
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
||||
pylint
|
||||
ruff
|
||||
fi
|
||||
|
||||
}
|
||||
|
||||
# Run Pylint
|
||||
echo 'vLLM Pylint:'
|
||||
# Run Ruff
|
||||
echo 'vLLM Ruff:'
|
||||
## This flag lints individual files. --files *must* be the first command line
|
||||
## arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
|
||||
33
patch_xformers.rocm.sh
Normal file
33
patch_xformers.rocm.sh
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
XFORMERS_VERSION="0.0.23"
|
||||
|
||||
export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
|
||||
|
||||
if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
|
||||
echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
|
||||
export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
|
||||
|
||||
echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
|
||||
echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
|
||||
|
||||
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
|
||||
echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
|
||||
patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
|
||||
echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
|
||||
else
|
||||
echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
|
||||
fi
|
||||
|
||||
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
|
||||
echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
|
||||
patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
|
||||
echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
|
||||
else
|
||||
echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
|
||||
fi
|
||||
@@ -1,9 +1,34 @@
|
||||
[build-system]
|
||||
# Should be mirrored in requirements-build.txt
|
||||
requires = [
|
||||
"ninja",
|
||||
"packaging",
|
||||
"setuptools",
|
||||
"torch >= 2.1.0",
|
||||
"setuptools >= 49.4.0",
|
||||
"torch == 2.1.2",
|
||||
"wheel",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
# pycodestyle
|
||||
"E",
|
||||
# Pyflakes
|
||||
"F",
|
||||
# pyupgrade
|
||||
# "UP",
|
||||
# flake8-bugbear
|
||||
"B",
|
||||
# flake8-simplify
|
||||
"SIM",
|
||||
# isort
|
||||
# "I",
|
||||
]
|
||||
ignore = [
|
||||
# star imports
|
||||
"F405", "F403",
|
||||
# lambda expression assignment
|
||||
"E731",
|
||||
# line too long, handled by black formatting
|
||||
"E501",
|
||||
]
|
||||
|
||||
6
requirements-build.txt
Normal file
6
requirements-build.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
# Should be mirrored in pyproject.toml
|
||||
ninja
|
||||
packaging
|
||||
setuptools>=49.4.0
|
||||
torch==2.1.2
|
||||
wheel
|
||||
@@ -1,6 +1,7 @@
|
||||
# formatting
|
||||
yapf==0.32.0
|
||||
pylint==2.8.2
|
||||
toml==0.10.2
|
||||
ruff==0.1.5
|
||||
|
||||
# type checking
|
||||
mypy==0.991
|
||||
|
||||
13
requirements-rocm.txt
Normal file
13
requirements-rocm.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
ninja # For faster builds.
|
||||
typing-extensions>=4.8.0
|
||||
starlette
|
||||
psutil
|
||||
ray >= 2.5.1
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
tokenizers>=0.15.0
|
||||
transformers >= 4.36.0 # Required for Mixtral.
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic == 1.10.13 # Required for OpenAI server.
|
||||
aioprometheus[starlette]
|
||||
@@ -1,14 +1,12 @@
|
||||
ninja # For faster builds.
|
||||
psutil
|
||||
ray >= 2.5.1
|
||||
pandas # Required for Ray data.
|
||||
pyarrow # Required for Ray data.
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
einops # Required for phi-1_5
|
||||
torch >= 2.1.0
|
||||
transformers >= 4.34.0 # Required for Mistral.
|
||||
xformers >= 0.0.22.post7 # Required for CUDA 12.1.
|
||||
torch == 2.1.2
|
||||
transformers >= 4.36.0 # Required for Mixtral.
|
||||
xformers == 0.0.23.post1 # Required for CUDA 12.1.
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic == 1.10.13 # Required for OpenAI server.
|
||||
aioprometheus[starlette]
|
||||
|
||||
13
rocm_patch/commonpy_xformers-0.0.23.rocm.patch
Normal file
13
rocm_patch/commonpy_xformers-0.0.23.rocm.patch
Normal file
@@ -0,0 +1,13 @@
|
||||
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000
|
||||
+++ common.py 2023-11-28 16:14:19.846233146 +0000
|
||||
@@ -298,8 +298,8 @@
|
||||
dtype = d.query.dtype
|
||||
if device_type not in cls.SUPPORTED_DEVICES:
|
||||
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
|
||||
- if device_type == "cuda" and not _built_with_cuda:
|
||||
- reasons.append("xFormers wasn't build with CUDA support")
|
||||
+ #if device_type == "cuda" and not _built_with_cuda:
|
||||
+ # reasons.append("xFormers wasn't build with CUDA support")
|
||||
if device_type == "cuda":
|
||||
device_capability = torch.cuda.get_device_capability(d.device)
|
||||
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
|
||||
152
rocm_patch/flashpy_xformers-0.0.23.rocm.patch
Normal file
152
rocm_patch/flashpy_xformers-0.0.23.rocm.patch
Normal file
@@ -0,0 +1,152 @@
|
||||
--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
|
||||
+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
|
||||
@@ -36,44 +36,44 @@
|
||||
|
||||
FLASH_VERSION = "0.0.0"
|
||||
try:
|
||||
- try:
|
||||
- from ... import _C_flashattention # type: ignore[attr-defined]
|
||||
- from ..._cpp_lib import _build_metadata
|
||||
-
|
||||
- if _build_metadata is not None:
|
||||
- FLASH_VERSION = _build_metadata.flash_version
|
||||
- except ImportError:
|
||||
- import flash_attn
|
||||
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
|
||||
-
|
||||
- FLASH_VERSION = flash_attn.__version__
|
||||
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
|
||||
- if (
|
||||
- flash_ver_parsed != (2, 3, 6)
|
||||
- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
|
||||
- ):
|
||||
- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
|
||||
+ #try:
|
||||
+ # from ... import _C_flashattention # type: ignore[attr-defined]
|
||||
+ # from ..._cpp_lib import _build_metadata
|
||||
+
|
||||
+ # if _build_metadata is not None:
|
||||
+ # FLASH_VERSION = _build_metadata.flash_version
|
||||
+ #except ImportError:
|
||||
+ import flash_attn
|
||||
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
|
||||
+
|
||||
+ FLASH_VERSION = flash_attn.__version__
|
||||
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
|
||||
+ # if (
|
||||
+ # flash_ver_parsed != (2, 3, 6)
|
||||
+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
|
||||
+ # ):
|
||||
+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
|
||||
|
||||
# create library so that flash-attn goes through the PyTorch Dispatcher
|
||||
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
|
||||
-
|
||||
- _flash_lib.define(
|
||||
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
|
||||
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
|
||||
- "int max_seqlen_q, int max_seqlen_k, "
|
||||
- "float p, float softmax_scale, "
|
||||
- "bool is_causal, int window_left, "
|
||||
- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
|
||||
- )
|
||||
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
|
||||
|
||||
- _flash_lib.define(
|
||||
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
|
||||
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
|
||||
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
|
||||
- "int max_seqlen_q, int max_seqlen_k, "
|
||||
- "float p, float softmax_scale, bool is_causal, "
|
||||
- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
|
||||
- )
|
||||
+ #_flash_lib.define(
|
||||
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
|
||||
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
|
||||
+ # "int max_seqlen_q, int max_seqlen_k, "
|
||||
+ # "float p, float softmax_scale, "
|
||||
+ # "bool is_causal, int window_left, "
|
||||
+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
|
||||
+ #)
|
||||
+
|
||||
+ #_flash_lib.define(
|
||||
+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
|
||||
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
|
||||
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
|
||||
+ # "int max_seqlen_q, int max_seqlen_k, "
|
||||
+ # "float p, float softmax_scale, bool is_causal, "
|
||||
+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
|
||||
+ #)
|
||||
|
||||
def _flash_fwd(
|
||||
query,
|
||||
@@ -111,8 +111,8 @@
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
- window_left, # window_size_left
|
||||
- window_right, # window_size_right
|
||||
+ # window_left, # window_size_left
|
||||
+ # window_right, # window_size_right
|
||||
return_softmax,
|
||||
None, # rng
|
||||
)
|
||||
@@ -134,15 +134,15 @@
|
||||
out,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
- seqused_k,
|
||||
+ # seqused_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
False,
|
||||
is_causal,
|
||||
- window_left,
|
||||
- window_right,
|
||||
+ # window_left,
|
||||
+ # window_right,
|
||||
return_softmax,
|
||||
None,
|
||||
)
|
||||
@@ -184,8 +184,8 @@
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
- window_left,
|
||||
- window_right,
|
||||
+ # window_left,
|
||||
+ # window_right,
|
||||
None,
|
||||
rng_state,
|
||||
)
|
||||
@@ -208,15 +208,15 @@
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
is_causal,
|
||||
- window_left,
|
||||
- window_right,
|
||||
+ # window_left,
|
||||
+ # window_right,
|
||||
None,
|
||||
rng_state,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
|
||||
- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
|
||||
+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
|
||||
+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -400,7 +400,7 @@
|
||||
implementation.
|
||||
"""
|
||||
|
||||
- OPERATOR = get_operator("xformers_flash", "flash_fwd")
|
||||
+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
|
||||
SUPPORTED_DEVICES: Set[str] = {"cuda"}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
|
||||
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
|
||||
296
setup.py
296
setup.py
@@ -8,27 +8,83 @@ import warnings
|
||||
from packaging.version import parse, Version
|
||||
import setuptools
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
|
||||
|
||||
ROOT_DIR = os.path.dirname(__file__)
|
||||
|
||||
MAIN_CUDA_VERSION = "12.1"
|
||||
|
||||
# Supported NVIDIA GPU architectures.
|
||||
SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
||||
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
||||
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
|
||||
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
|
||||
|
||||
|
||||
def _is_hip() -> bool:
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
def _is_cuda() -> bool:
|
||||
return torch.version.cuda is not None
|
||||
|
||||
|
||||
# Compiler flags.
|
||||
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
||||
# TODO(woosuk): Should we use -O3?
|
||||
NVCC_FLAGS = ["-O2", "-std=c++17"]
|
||||
|
||||
if _is_hip():
|
||||
if ROCM_HOME is None:
|
||||
raise RuntimeError(
|
||||
"Cannot find ROCM_HOME. ROCm must be available to build the package."
|
||||
)
|
||||
NVCC_FLAGS += ["-DUSE_ROCM"]
|
||||
|
||||
if _is_cuda() and CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
|
||||
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
|
||||
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
|
||||
def get_amdgpu_offload_arch():
|
||||
command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
|
||||
try:
|
||||
output = subprocess.check_output([command])
|
||||
return output.decode('utf-8').strip()
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_message = f"Error: {e}"
|
||||
raise RuntimeError(error_message) from e
|
||||
except FileNotFoundError as e:
|
||||
# If the command is not found, raise a RuntimeError with a descriptive message
|
||||
error_message = f"The command {command} was not found."
|
||||
raise RuntimeError(error_message) from e
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_hipcc_rocm_version():
|
||||
# Run the hipcc --version command
|
||||
result = subprocess.run(['hipcc', '--version'],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True)
|
||||
|
||||
# Check if the command was executed successfully
|
||||
if result.returncode != 0:
|
||||
print("Error running 'hipcc --version'")
|
||||
return None
|
||||
|
||||
# Extract the version using a regular expression
|
||||
match = re.search(r'HIP version: (\S+)', result.stdout)
|
||||
if match:
|
||||
# Return the version string
|
||||
return match.group(1)
|
||||
else:
|
||||
print("Could not find HIP version in the output")
|
||||
return None
|
||||
|
||||
|
||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
@@ -61,27 +117,30 @@ def get_torch_arch_list() -> Set[str]:
|
||||
return set()
|
||||
|
||||
# Filter out the invalid architectures and print a warning.
|
||||
valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
|
||||
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
|
||||
{s + "+PTX"
|
||||
for s in NVIDIA_SUPPORTED_ARCHS})
|
||||
arch_list = torch_arch_list.intersection(valid_archs)
|
||||
# If none of the specified architectures are valid, raise an error.
|
||||
if not arch_list:
|
||||
raise RuntimeError(
|
||||
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||
f"variable ({env_arch_list}) is supported. "
|
||||
f"Supported CUDA architectures are: {valid_archs}.")
|
||||
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
|
||||
invalid_arch_list = torch_arch_list - valid_archs
|
||||
if invalid_arch_list:
|
||||
warnings.warn(
|
||||
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
|
||||
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
|
||||
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
|
||||
f"({env_arch_list}). Supported CUDA architectures are: "
|
||||
f"{valid_archs}.")
|
||||
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
|
||||
f"{valid_archs}.",
|
||||
stacklevel=2)
|
||||
return arch_list
|
||||
|
||||
|
||||
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||
compute_capabilities = get_torch_arch_list()
|
||||
if not compute_capabilities:
|
||||
if _is_cuda() and not compute_capabilities:
|
||||
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||
# GPUs on the current machine.
|
||||
device_count = torch.cuda.device_count()
|
||||
@@ -92,135 +151,91 @@ if not compute_capabilities:
|
||||
"GPUs with compute capability below 7.0 are not supported.")
|
||||
compute_capabilities.add(f"{major}.{minor}")
|
||||
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if not compute_capabilities:
|
||||
# If no GPU is specified nor available, add all supported architectures
|
||||
# based on the NVCC CUDA version.
|
||||
compute_capabilities = SUPPORTED_ARCHS.copy()
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
compute_capabilities.remove("8.6")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
compute_capabilities.remove("8.9")
|
||||
compute_capabilities.remove("9.0")
|
||||
|
||||
# Validate the NVCC CUDA version.
|
||||
if nvcc_cuda_version < Version("11.0"):
|
||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
if any(cc.startswith("8.6") for cc in compute_capabilities):
|
||||
if _is_cuda():
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if not compute_capabilities:
|
||||
# If no GPU is specified nor available, add all supported architectures
|
||||
# based on the NVCC CUDA version.
|
||||
compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
compute_capabilities.remove("8.6")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
compute_capabilities.remove("8.9")
|
||||
compute_capabilities.remove("9.0")
|
||||
# Validate the NVCC CUDA version.
|
||||
if nvcc_cuda_version < Version("11.0"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.0 or higher is required to build the package.")
|
||||
if (nvcc_cuda_version < Version("11.1")
|
||||
and any(cc.startswith("8.6") for cc in compute_capabilities)):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for compute capability 8.6.")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
warnings.warn(
|
||||
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||
"Targeting compute capability 8.0 instead.")
|
||||
compute_capabilities = set(cc for cc in compute_capabilities
|
||||
if not cc.startswith("8.9"))
|
||||
compute_capabilities.add("8.0+PTX")
|
||||
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
warnings.warn(
|
||||
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||
"Targeting compute capability 8.0 instead.",
|
||||
stacklevel=2)
|
||||
compute_capabilities = set(cc for cc in compute_capabilities
|
||||
if not cc.startswith("8.9"))
|
||||
compute_capabilities.add("8.0+PTX")
|
||||
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
num = capability[0] + capability[2]
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||
if capability.endswith("+PTX"):
|
||||
NVCC_FLAGS += [
|
||||
"-gencode", f"arch=compute_{num},code=compute_{num}"
|
||||
]
|
||||
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
|
||||
num_threads = min(os.cpu_count(), nvcc_threads)
|
||||
NVCC_FLAGS += ["--threads", str(num_threads)]
|
||||
|
||||
elif _is_hip():
|
||||
amd_arch = get_amdgpu_offload_arch()
|
||||
if amd_arch not in ROCM_SUPPORTED_ARCHS:
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
num = capability[0] + capability[2]
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||
if capability.endswith("+PTX"):
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
||||
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
num_threads = min(os.cpu_count(), 8)
|
||||
NVCC_FLAGS += ["--threads", str(num_threads)]
|
||||
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
|
||||
f"amdgpu_arch_found: {amd_arch}")
|
||||
|
||||
ext_modules = []
|
||||
|
||||
# Cache operations.
|
||||
cache_extension = CUDAExtension(
|
||||
name="vllm.cache_ops",
|
||||
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(cache_extension)
|
||||
vllm_extension_sources = [
|
||||
"csrc/cache_kernels.cu",
|
||||
"csrc/attention/attention_kernels.cu",
|
||||
"csrc/pos_encoding_kernels.cu",
|
||||
"csrc/activation_kernels.cu",
|
||||
"csrc/layernorm_kernels.cu",
|
||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
|
||||
"csrc/quantization/gptq/q_gemm.cu",
|
||||
"csrc/cuda_utils_kernels.cu",
|
||||
"csrc/pybind.cpp",
|
||||
]
|
||||
|
||||
# Attention kernels.
|
||||
attention_extension = CUDAExtension(
|
||||
name="vllm.attention_ops",
|
||||
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(attention_extension)
|
||||
if _is_cuda():
|
||||
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
|
||||
|
||||
# Positional encoding kernels.
|
||||
positional_encoding_extension = CUDAExtension(
|
||||
name="vllm.pos_encoding_ops",
|
||||
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
|
||||
vllm_extension = CUDAExtension(
|
||||
name="vllm._C",
|
||||
sources=vllm_extension_sources,
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(positional_encoding_extension)
|
||||
|
||||
# Layer normalization kernels.
|
||||
layernorm_extension = CUDAExtension(
|
||||
name="vllm.layernorm_ops",
|
||||
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(layernorm_extension)
|
||||
|
||||
# Activation kernels.
|
||||
activation_extension = CUDAExtension(
|
||||
name="vllm.activation_ops",
|
||||
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(activation_extension)
|
||||
|
||||
# Quantization kernels.
|
||||
quantization_extension = CUDAExtension(
|
||||
name="vllm.quantization_ops",
|
||||
sources=[
|
||||
"csrc/quantization.cpp",
|
||||
"csrc/quantization/awq/gemm_kernels.cu",
|
||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(quantization_extension)
|
||||
|
||||
# Misc. CUDA utils.
|
||||
cuda_utils_extension = CUDAExtension(
|
||||
name="vllm.cuda_utils",
|
||||
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(cuda_utils_extension)
|
||||
ext_modules.append(vllm_extension)
|
||||
|
||||
|
||||
def get_path(*filepath) -> str:
|
||||
@@ -242,10 +257,19 @@ def find_version(filepath: str) -> str:
|
||||
|
||||
def get_vllm_version() -> str:
|
||||
version = find_version(get_path("vllm", "__init__.py"))
|
||||
cuda_version = str(nvcc_cuda_version)
|
||||
if cuda_version != MAIN_CUDA_VERSION:
|
||||
cuda_version_str = cuda_version.replace(".", "")[:3]
|
||||
version += f"+cu{cuda_version_str}"
|
||||
|
||||
if _is_hip():
|
||||
# Get the HIP version
|
||||
hipcc_version = get_hipcc_rocm_version()
|
||||
if hipcc_version != MAIN_CUDA_VERSION:
|
||||
rocm_version_str = hipcc_version.replace(".", "")[:3]
|
||||
version += f"+rocm{rocm_version_str}"
|
||||
else:
|
||||
cuda_version = str(nvcc_cuda_version)
|
||||
if cuda_version != MAIN_CUDA_VERSION:
|
||||
cuda_version_str = cuda_version.replace(".", "")[:3]
|
||||
version += f"+cu{cuda_version_str}"
|
||||
|
||||
return version
|
||||
|
||||
|
||||
@@ -260,8 +284,12 @@ def read_readme() -> str:
|
||||
|
||||
def get_requirements() -> List[str]:
|
||||
"""Get Python package dependencies from requirements.txt."""
|
||||
with open(get_path("requirements.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
if _is_hip():
|
||||
with open(get_path("requirements-rocm.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
else:
|
||||
with open(get_path("requirements.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
return requirements
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ app = vllm.entrypoints.api_server.app
|
||||
|
||||
class AsyncLLMEngineWithStats(AsyncLLMEngine):
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._num_aborts = 0
|
||||
|
||||
@@ -8,11 +8,11 @@ import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _query_server(prompt: str) -> dict:
|
||||
def _query_server(prompt: str, max_tokens: int = 5) -> dict:
|
||||
response = requests.post("http://localhost:8000/generate",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"max_tokens": 100,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0,
|
||||
"ignore_eos": True
|
||||
})
|
||||
@@ -20,11 +20,14 @@ def _query_server(prompt: str) -> dict:
|
||||
return response.json()
|
||||
|
||||
|
||||
def _query_server_long(prompt: str) -> dict:
|
||||
return _query_server(prompt, max_tokens=500)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server():
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
# pylint: disable=consider-using-with
|
||||
uvicorn_process = subprocess.Popen([
|
||||
sys.executable, "-u",
|
||||
str(script_path), "--model", "facebook/opt-125m"
|
||||
@@ -33,7 +36,6 @@ def api_server():
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
# pylint: disable=redefined-outer-name, unused-argument
|
||||
def test_api_server(api_server):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
@@ -46,14 +48,14 @@ def test_api_server(api_server):
|
||||
"""
|
||||
with Pool(32) as pool:
|
||||
# Wait until the server is ready
|
||||
prompts = ["Hello world"] * 1
|
||||
prompts = ["warm up"] * 1
|
||||
result = None
|
||||
while not result:
|
||||
# pylint: disable=bare-except
|
||||
try:
|
||||
for result in pool.map(_query_server, prompts):
|
||||
for r in pool.map(_query_server, prompts):
|
||||
result = r
|
||||
break
|
||||
except:
|
||||
except requests.exceptions.ConnectionError:
|
||||
time.sleep(1)
|
||||
|
||||
# Actual tests start here
|
||||
@@ -66,12 +68,14 @@ def test_api_server(api_server):
|
||||
assert num_aborted_requests == 0
|
||||
|
||||
# Try with 100 prompts
|
||||
prompts = ["Hello world"] * 100
|
||||
prompts = ["test prompt"] * 100
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
|
||||
with Pool(32) as pool:
|
||||
# Cancel requests
|
||||
pool.map_async(_query_server, prompts)
|
||||
prompts = ["canceled requests"] * 100
|
||||
pool.map_async(_query_server_long, prompts)
|
||||
time.sleep(0.01)
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
@@ -84,6 +88,6 @@ def test_api_server(api_server):
|
||||
# check that server still runs after cancellations
|
||||
with Pool(32) as pool:
|
||||
# Try with 100 prompts
|
||||
prompts = ["Hello world"] * 100
|
||||
prompts = ["test prompt after canceled"] * 100
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
|
||||
119
tests/async_engine/test_openai_server.py
Normal file
119
tests/async_engine/test_openai_server.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from vllm.entrypoints.openai.api_server import *
|
||||
|
||||
# Define models, templates, and their corresponding expected outputs
|
||||
MODEL_TEMPLATE_GENERATON_OUTPUT = [
|
||||
("facebook/opt-125m", None, True,
|
||||
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
||||
("facebook/opt-125m", None, False,
|
||||
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
||||
("facebook/opt-125m", "../../examples/template_chatml.jinja", True,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"""),
|
||||
("facebook/opt-125m", "../../examples/template_chatml.jinja", False,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of""")
|
||||
]
|
||||
|
||||
TEST_MESSAGES = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': 'Hi there!'
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'What is the capital of'
|
||||
},
|
||||
]
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockTokenizer:
|
||||
chat_template = None
|
||||
|
||||
|
||||
def test_load_chat_template():
|
||||
# Testing chatml template
|
||||
template = "../../examples/template_chatml.jinja"
|
||||
mock_args = Namespace(chat_template=template)
|
||||
tokenizer = MockTokenizer()
|
||||
|
||||
# Call the function with the mocked args
|
||||
load_chat_template(mock_args, tokenizer)
|
||||
|
||||
template_content = tokenizer.chat_template
|
||||
|
||||
# Test assertions
|
||||
assert template_content is not None
|
||||
# Hard coded value for template_chatml.jinja
|
||||
assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
|
||||
|
||||
|
||||
def test_no_load_chat_template():
|
||||
# Testing chatml template
|
||||
template = "../../examples/does_not_exist"
|
||||
mock_args = Namespace(chat_template=template)
|
||||
tokenizer = MockTokenizer()
|
||||
|
||||
# Call the function with the mocked args
|
||||
load_chat_template(mock_args, tokenizer=tokenizer)
|
||||
template_content = tokenizer.chat_template
|
||||
|
||||
# Test assertions
|
||||
assert template_content is not None
|
||||
# Hard coded value for template_chatml.jinja
|
||||
assert template_content == """../../examples/does_not_exist"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model,template,add_generation_prompt,expected_output",
|
||||
MODEL_TEMPLATE_GENERATON_OUTPUT)
|
||||
async def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
expected_output):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = get_tokenizer(tokenizer_name=model)
|
||||
|
||||
mock_args = Namespace(chat_template=template)
|
||||
load_chat_template(mock_args, tokenizer)
|
||||
|
||||
# Create a mock request object using keyword arguments
|
||||
mock_request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=TEST_MESSAGES,
|
||||
add_generation_prompt=add_generation_prompt)
|
||||
|
||||
# Call the function and get the result
|
||||
result = tokenizer.apply_chat_template(
|
||||
conversation=mock_request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=mock_request.add_generation_prompt)
|
||||
|
||||
# Test assertion
|
||||
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
|
||||
|
||||
|
||||
def test_health_endpoint():
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
@@ -7,22 +8,33 @@ from transformers import AutoModelForCausalLM
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
# pylint: disable=line-too-long
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||
"Describe the basic components of a neural network and how it can be trained.",
|
||||
"Write a short story about a robot that dreams for the first time.",
|
||||
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
|
||||
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
|
||||
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
|
||||
]
|
||||
_TEST_DIR = os.path.dirname(__file__)
|
||||
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
|
||||
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
|
||||
|
||||
|
||||
def _read_prompts(filename: str) -> str:
|
||||
prompts = []
|
||||
with open(filename, "r") as f:
|
||||
prompt = f.readline()
|
||||
prompts.append(prompt)
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_prompts() -> List[str]:
|
||||
return _TEST_PROMPTS
|
||||
prompts = []
|
||||
for filename in _TEST_PROMPTS:
|
||||
prompts += _read_prompts(filename)
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_long_prompts() -> List[str]:
|
||||
prompts = []
|
||||
for filename in _LONG_PROMPTS:
|
||||
prompts += _read_prompts(filename)
|
||||
return prompts
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.engine.ray_utils import get_open_port
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_all_gather,
|
||||
|
||||
@@ -5,10 +5,9 @@ from transformers import AutoTokenizer
|
||||
from vllm.transformers_utils.tokenizer import detokenize_incrementally
|
||||
|
||||
TRUTH = [
|
||||
# pylint: disable=line-too-long
|
||||
"Hello here, this is a simple test",
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
|
||||
"我很感谢你的热情"
|
||||
"Hello here, this is a simple test", # noqa: E501
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa: E501
|
||||
"我很感谢你的热情" # noqa: E501
|
||||
]
|
||||
TOKENIZERS = [
|
||||
"facebook/opt-125m",
|
||||
|
||||
@@ -12,6 +12,7 @@ def create_kv_caches(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
@@ -23,7 +24,7 @@ def create_kv_caches(
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
device=device)
|
||||
key_cache.uniform_(-scale, scale)
|
||||
key_caches.append(key_cache)
|
||||
|
||||
@@ -32,7 +33,7 @@ def create_kv_caches(
|
||||
for _ in range(num_layers):
|
||||
value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
device=device)
|
||||
value_cache.uniform_(-scale, scale)
|
||||
value_caches.append(value_cache)
|
||||
return key_caches, value_caches
|
||||
|
||||
@@ -1,38 +1,35 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.activations import get_activation
|
||||
|
||||
from vllm import activation_ops
|
||||
from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x.chunk(chunks=2, dim=1)
|
||||
return F.silu(x1) * x2
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_and_mul(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
ref_out = ref_silu_and_mul(x)
|
||||
gpu_id = f"cuda:{device}"
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
|
||||
layer = SiluAndMul()
|
||||
out = layer(x)
|
||||
ref_out = layer._forward(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@@ -40,19 +37,22 @@ def test_silu_and_mul(
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_gelu_new(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.gelu_new(out, x)
|
||||
ref_out = get_activation("gelu_new")(x)
|
||||
gpu_id = f"cuda:{device}"
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
|
||||
layer = NewGELU()
|
||||
out = layer(x)
|
||||
ref_out = layer._forward(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@@ -60,16 +60,19 @@ def test_gelu_new(
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_gelu_fast(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.gelu_fast(out, x)
|
||||
ref_out = get_activation("gelu_fast")(x)
|
||||
gpu_id = f"cuda:{device}"
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
|
||||
layer = FastGELU()
|
||||
out = layer(x)
|
||||
ref_out = layer._forward(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from vllm import attention_ops
|
||||
from vllm._C import ops
|
||||
from vllm.utils import get_max_shared_memory_bytes
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
@@ -24,6 +24,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
|
||||
alibi_bias = None
|
||||
if alibi_slopes is not None:
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(context_len, device="cuda").int()
|
||||
position_ids = torch.arange(context_len, device=query.device).int()
|
||||
alibi_bias = (position_ids - context_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
@@ -105,6 +106,7 @@ def ref_single_query_cached_kv_attention(
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_paged_attention(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
@@ -115,35 +117,33 @@ def test_paged_attention(
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
gpu_id = f"cuda:{device}"
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
device=gpu_id)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device="cuda")
|
||||
device=gpu_id)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
context_lens[-1] = MAX_SEQ_LEN
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
@@ -154,23 +154,23 @@ def test_paged_attention(
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size, dtype,
|
||||
seed)
|
||||
seed, gpu_id)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v1":
|
||||
attention_ops.paged_attention_v1(
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
@@ -194,7 +194,7 @@ def test_paged_attention(
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
attention_ops.paged_attention_v2(
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
@@ -202,7 +202,7 @@ def test_paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
@@ -211,7 +211,7 @@ def test_paged_attention(
|
||||
alibi_slopes,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unknown version: {version}"
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
|
||||
# Run the reference implementation.
|
||||
ref_output = torch.empty_like(query)
|
||||
@@ -252,7 +252,7 @@ def ref_multi_query_kv_attention(
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype, device="cuda")
|
||||
attn_mask = attn_mask.to(dtype=dtype, device=query.device)
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
@@ -272,6 +272,7 @@ def ref_multi_query_kv_attention(
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
@@ -279,11 +280,12 @@ def test_multi_query_kv_attention(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
gpu_id = f"cuda:{device}"
|
||||
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
|
||||
# As the xformers library is already tested with its own tests, we can use
|
||||
# a smaller MAX_SEQ_LEN here.
|
||||
@@ -297,7 +299,7 @@ def test_multi_query_kv_attention(
|
||||
num_query_heads + 2 * num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
device=gpu_id)
|
||||
qkv.uniform_(-scale, scale)
|
||||
query, key, value = qkv.split(
|
||||
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import cache_ops
|
||||
from vllm._C import cache_ops
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [83] # Arbitrary values for testing
|
||||
@@ -14,6 +14,7 @@ BLOCK_SIZES = [8, 16, 32]
|
||||
NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing
|
||||
NUM_MAPPINGS = [256] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@@ -24,6 +25,7 @@ SEEDS = [0]
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_copy_blocks(
|
||||
kv_cache_factory,
|
||||
@@ -35,11 +37,12 @@ def test_copy_blocks(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
gpu_id = f"cuda:{device}"
|
||||
# Generate random block mappings where each source block is mapped to two
|
||||
# destination blocks.
|
||||
assert 2 * num_mappings <= num_blocks
|
||||
@@ -56,7 +59,7 @@ def test_copy_blocks(
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||
num_layers, num_heads,
|
||||
head_size, dtype, seed)
|
||||
head_size, dtype, seed, gpu_id)
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
@@ -88,6 +91,7 @@ def test_copy_blocks(
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache(
|
||||
kv_cache_factory,
|
||||
@@ -98,28 +102,29 @@ def test_reshape_and_cache(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
gpu_id = f"cuda:{device}"
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
|
||||
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
device=gpu_id)
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size, dtype,
|
||||
seed)
|
||||
seed, gpu_id)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Clone the KV caches.
|
||||
|
||||
@@ -1,58 +1,50 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import layernorm_ops
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing
|
||||
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
||||
HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing
|
||||
ADD_RESIDUAL = [False, True]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
class RefRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
weight = torch.empty(hidden_size)
|
||||
weight.normal_(mean=1.0, std=0.1)
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance +
|
||||
self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_rms_norm(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
|
||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||
scale = 1 / (2 * hidden_size)
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
|
||||
x *= scale
|
||||
residual = torch.randn_like(x) * scale if add_residual else None
|
||||
|
||||
scale = float(hidden_size**-0.5)
|
||||
x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
x.uniform_(-scale, scale)
|
||||
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
|
||||
|
||||
out = torch.empty_like(x)
|
||||
layernorm_ops.rms_norm(
|
||||
out,
|
||||
x,
|
||||
ref.weight.data,
|
||||
ref.variance_epsilon,
|
||||
)
|
||||
ref_out = ref(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_out = layer._forward(x, residual)
|
||||
out = layer(x, residual)
|
||||
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
|
||||
# numerical errors than other operators because they involve reductions.
|
||||
# Therefore, we use a larger tolerance.
|
||||
if add_residual:
|
||||
assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
|
||||
assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
|
||||
else:
|
||||
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@@ -1,119 +1,41 @@
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import pos_encoding_ops
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
IS_NEOX_STYLE = [True, False]
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
|
||||
NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing
|
||||
NUM_HEADS = [7, 17] # Arbitrary values for testing
|
||||
BATCH_SIZES = [1, 5] # Arbitrary values for testing
|
||||
SEQ_LENS = [11, 8192] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return x.flatten(-2)
|
||||
|
||||
|
||||
def apply_rope(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
|
||||
q_embed = (q * cos) + (rotate_fn(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_fn(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class RefRotaryEmbedding(nn.Module):
|
||||
"""Reference implementation of rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
is_neox_style: bool,
|
||||
max_position_embeddings: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.rotary_dim = dim
|
||||
self.is_neox_style = is_neox_style
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# Create cos and sin embeddings.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
|
||||
t = torch.arange(max_position_embeddings).float()
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
||||
if is_neox_style:
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
else:
|
||||
emb = torch.repeat_interleave(freqs, 2, -1)
|
||||
cos = emb.cos().to(dtype=inv_freq.dtype)
|
||||
sin = emb.sin().to(dtype=inv_freq.dtype)
|
||||
self.register_buffer("cos_cached", cos, persistent=False)
|
||||
self.register_buffer("sin_cached", sin, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor, # [num_tokens]
|
||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
|
||||
query_rot = query_rot.transpose(0, 1)
|
||||
key_rot = key_rot.transpose(0, 1)
|
||||
cos = F.embedding(positions, self.cos_cached)
|
||||
sin = F.embedding(positions, self.sin_cached)
|
||||
|
||||
query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
query_rot = query_rot.transpose(0, 1).contiguous()
|
||||
key_rot = key_rot.transpose(0, 1).contiguous()
|
||||
|
||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||
|
||||
# Output query/key shape: [num_tokens, num_tokens, head_size]
|
||||
return query, key
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
num_tokens: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
@@ -121,54 +43,26 @@ def test_rotary_embedding(
|
||||
rotary_dim = head_size
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
|
||||
rope = rope.to(dtype=dtype, device=gpu_id)
|
||||
|
||||
positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
|
||||
query = torch.randn(num_tokens,
|
||||
positions = torch.randint(0,
|
||||
max_position, (batch_size, seq_len),
|
||||
device=gpu_id)
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
key = torch.randn(num_tokens,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
# Create the rotary embedding.
|
||||
inv_freq = 1.0 / (base**(
|
||||
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
|
||||
|
||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||
out_query = query.clone()
|
||||
out_key = key.clone()
|
||||
pos_encoding_ops.rotary_embedding(
|
||||
positions,
|
||||
out_query,
|
||||
out_key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
ref_rotary_embedding = RefRotaryEmbedding(
|
||||
dim=rotary_dim,
|
||||
is_neox_style=is_neox_style,
|
||||
max_position_embeddings=max_position,
|
||||
base=base,
|
||||
).to(dtype=dtype, device="cuda")
|
||||
ref_query, ref_key = ref_rotary_embedding(
|
||||
positions,
|
||||
query.view(num_tokens, num_heads, head_size),
|
||||
key.view(num_tokens, num_heads, head_size),
|
||||
)
|
||||
ref_query = ref_query.view(num_tokens, num_heads * head_size)
|
||||
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
||||
device=gpu_id)
|
||||
key = torch.randn_like(query)
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_query, ref_key = rope._forward(positions, query, key)
|
||||
out_query, out_key = rope.forward(positions, query, key)
|
||||
# Compare the results.
|
||||
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
|
||||
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
|
||||
|
||||
37
tests/models/test_mistral.py
Normal file
37
tests/models/test_mistral.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_mistral.py --forked`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
    hf_runner,
    vllm_runner,
    example_long_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Check that HF and vLLM produce identical greedy outputs.

    Each runner is built, used for one ``generate_greedy`` call over
    ``example_long_prompts`` and deleted before the next one is created,
    so the two models are never alive at the same time in this function.
    """
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.generate_greedy(example_long_prompts,
                                              max_tokens)
    del vllm_model

    # Fail with a clear message (instead of an IndexError mid-loop) when a
    # runner returns a different number of results than prompts.
    num_prompts = len(example_long_prompts)
    assert len(hf_outputs) == num_prompts, (
        f"HF returned {len(hf_outputs)} outputs for {num_prompts} prompts")
    assert len(vllm_outputs) == num_prompts, (
        f"vLLM returned {len(vllm_outputs)} outputs for {num_prompts} "
        "prompts")

    for i, (hf_output, vllm_output) in enumerate(
            zip(hf_outputs, vllm_outputs)):
        hf_output_ids, hf_output_str = hf_output
        vllm_output_ids, vllm_output_str = vllm_output
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
@@ -8,6 +8,7 @@ MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
"Deci/DeciLM-7b",
|
||||
"tiiuae/falcon-7b",
|
||||
"gpt2",
|
||||
"bigcode/tiny_starcoder_py",
|
||||
@@ -15,12 +16,12 @@ MODELS = [
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"mosaicml/mpt-7b",
|
||||
"microsoft/phi-1_5",
|
||||
"microsoft/phi-2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
|
||||
8
tests/prompts/example.txt
Normal file
8
tests/prompts/example.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
|
||||
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
|
||||
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
|
||||
Describe the basic components of a neural network and how it can be trained.
|
||||
Write a short story about a robot that dreams for the first time.
|
||||
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
|
||||
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
|
||||
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'
|
||||
1
tests/prompts/summary.txt
Normal file
1
tests/prompts/summary.txt
Normal file
File diff suppressed because one or more lines are too long
@@ -1,4 +1,3 @@
|
||||
# pylint: disable=protected-access
|
||||
import random
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch
|
||||
@@ -9,7 +8,7 @@ import torch
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
|
||||
|
||||
class MockLogitsSampler(Sampler):
|
||||
@@ -20,15 +19,15 @@ class MockLogitsSampler(Sampler):
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
|
||||
lambda x, y: x):
|
||||
with patch("vllm.model_executor.layers.sampler._get_logits",
|
||||
lambda x, y: x), patch(
|
||||
"vllm.model_executor.layers.sampler._get_logits",
|
||||
lambda *args, **kwargs: self.fake_logits):
|
||||
return super().forward(*args, **kwargs)
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
def _prepare_test(
|
||||
batch_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024),
|
||||
device="cuda",
|
||||
@@ -38,9 +37,8 @@ def _prepare_test(
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(32000, fake_logits)
|
||||
worker = Worker(None, None, None)
|
||||
worker.block_size = 16
|
||||
return input_tensor, fake_logits, sampler, worker
|
||||
model_runner = ModelRunner(None, None, None)
|
||||
return input_tensor, fake_logits, sampler, model_runner
|
||||
|
||||
|
||||
RANDOM_SEEDS = list(range(128))
|
||||
@@ -50,9 +48,11 @@ RANDOM_SEEDS = list(range(128))
|
||||
def test_sampler_all_greedy(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
@@ -62,11 +62,13 @@ def test_sampler_all_greedy(seed: int):
|
||||
sampling_params=SamplingParams(temperature=0, ),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
sampling_metadata=sampling_metadata)
|
||||
expected = torch.argmax(fake_logits, dim=-1)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
@@ -77,12 +79,14 @@ def test_sampler_all_greedy(seed: int):
|
||||
def test_sampler_all_random(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
|
||||
for i in range(batch_size):
|
||||
fake_logits[i, i] = 1e2
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
@@ -95,11 +99,13 @@ def test_sampler_all_random(seed: int):
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
sampling_metadata=sampling_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token == i
|
||||
@@ -109,9 +115,10 @@ def test_sampler_all_random(seed: int):
|
||||
def test_sampler_all_beam(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, _, sampler, worker = _prepare_test(batch_size)
|
||||
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
@@ -125,11 +132,13 @@ def test_sampler_all_beam(seed: int):
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
sampling_metadata=sampling_metadata)
|
||||
# no assertion here as I am not sure how to determine whether
|
||||
# the outputs are expected - in other words, this just tests
|
||||
# whether there are no exceptions in the sampler
|
||||
@@ -140,10 +149,12 @@ def test_sampler_all_beam(seed: int):
|
||||
def test_sampler_mixed(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
expected_tokens = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
n = 1
|
||||
sampling_type = random.randint(0, 2)
|
||||
@@ -173,11 +184,13 @@ def test_sampler_mixed(seed: int):
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
sampling_metadata=sampling_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
if seq_group_metadata_list[i].sampling_params.use_beam_search:
|
||||
continue
|
||||
@@ -189,7 +202,7 @@ def test_sampler_mixed(seed: int):
|
||||
def test_sampler_logits_processors(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, _, sampler, worker = _prepare_test(batch_size)
|
||||
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
|
||||
|
||||
# This sample logits processor gives infinite score to the i-th token,
|
||||
# where i is the length of the input sequence.
|
||||
@@ -199,6 +212,7 @@ def test_sampler_logits_processors(seed: int):
|
||||
return logits
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
@@ -209,11 +223,13 @@ def test_sampler_logits_processors(seed: int):
|
||||
logits_processors=[pick_ith]),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
sampling_metadata=sampling_metadata)
|
||||
for _, sequence_output in enumerate(sampler_output):
|
||||
for idx, nth_output in enumerate(sequence_output.samples):
|
||||
assert nth_output.output_token == idx
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# pylint: disable=protected-access
|
||||
import random
|
||||
import torch
|
||||
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
|
||||
|
||||
def test_worker_prepare_inputs_for_prompt():
|
||||
worker = Worker(None, None, None)
|
||||
worker.block_size = 16
|
||||
def test_prepare_prompt():
|
||||
model_runner = ModelRunner(None, None, None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
batch_size = random.randint(1, 256)
|
||||
prompt_lens = []
|
||||
seq_group_metadata_list = []
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (worker.block_size - 1) + 1
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = list(range(prompt_len))
|
||||
seq_group_metadata_list.append(
|
||||
@@ -25,6 +25,7 @@ def test_worker_prepare_inputs_for_prompt():
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
||||
expected_selected_token_indices = []
|
||||
selected_token_start_idx = 0
|
||||
max_seq_len = max(prompt_lens)
|
||||
@@ -32,12 +33,16 @@ def test_worker_prepare_inputs_for_prompt():
|
||||
expected_selected_token_indices.append(selected_token_start_idx +
|
||||
prompt_len - 1)
|
||||
selected_token_start_idx += max_seq_len
|
||||
input_tokens, input_positions, input_metadata = worker._prepare_inputs(
|
||||
seq_group_metadata_list)
|
||||
assert input_tokens.shape == input_positions.shape == (batch_size,
|
||||
max_seq_len)
|
||||
input_tokens, input_positions, _, return_prompt_lens = (
|
||||
model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert return_prompt_lens == prompt_lens
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
assert input_tokens.shape == (batch_size, max_seq_len)
|
||||
assert input_positions.shape == (batch_size, max_seq_len)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
actual = input_metadata.selected_token_indices
|
||||
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
__version__ = "0.2.2"
|
||||
__version__ = "0.2.7"
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
|
||||
@@ -6,7 +6,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.utils import get_cpu_memory
|
||||
from vllm.utils import get_cpu_memory, is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -49,6 +49,12 @@ class ModelConfig:
|
||||
output). If None, will be derived from the model.
|
||||
quantization: Quantization method that was used to quantize the model
|
||||
weights. If None, we assume the model weights are not quantized.
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
If False, we will use CUDA graph and eager execution in hybrid.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -65,6 +71,8 @@ class ModelConfig:
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
enforce_eager: bool = False,
|
||||
max_context_len_to_capture: Optional[int] = None,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
@@ -76,6 +84,8 @@ class ModelConfig:
|
||||
self.revision = revision
|
||||
self.tokenizer_revision = tokenizer_revision
|
||||
self.quantization = quantization
|
||||
self.enforce_eager = enforce_eager
|
||||
self.max_context_len_to_capture = max_context_len_to_capture
|
||||
|
||||
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
|
||||
# download model from ModelScope hub,
|
||||
@@ -95,15 +105,34 @@ class ModelConfig:
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
if load_format not in [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]:
|
||||
supported_load_format = [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]
|
||||
rocm_not_supported_load_format = []
|
||||
if load_format not in supported_load_format:
|
||||
raise ValueError(
|
||||
f"Unknown load format: {self.load_format}. Must be one of "
|
||||
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||
if is_hip() and load_format in rocm_not_supported_load_format:
|
||||
rocm_supported_load_format = [
|
||||
f for f in supported_load_format
|
||||
if (f not in rocm_not_supported_load_format)
|
||||
]
|
||||
raise ValueError(
|
||||
f"load format \'{load_format}\' is not supported in ROCm. "
|
||||
f"Supported load format are "
|
||||
f"{rocm_supported_load_format}")
|
||||
|
||||
# TODO: Remove this check once HF updates the pt weights of Mixtral.
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
if "MixtralForCausalLM" in architectures and load_format == "pt":
|
||||
raise ValueError(
|
||||
"Currently, the 'pt' format is not supported for Mixtral. "
|
||||
"Please use the 'safetensors' format instead. ")
|
||||
self.load_format = load_format
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
@@ -115,7 +144,8 @@ class ModelConfig:
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq", "squeezellm"]
|
||||
supported_quantization = ["awq", "gptq", "squeezellm"]
|
||||
rocm_not_supported_quantization = ["awq"]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
@@ -137,10 +167,21 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
if is_hip(
|
||||
) and self.quantization in rocm_not_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not supported "
|
||||
f"in ROCm.")
|
||||
logger.warning(f"{self.quantization} quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.")
|
||||
|
||||
def _verify_cuda_graph(self) -> None:
|
||||
if self.max_context_len_to_capture is None:
|
||||
self.max_context_len_to_capture = self.max_model_len
|
||||
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
|
||||
self.max_model_len)
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
@@ -161,6 +202,12 @@ class ModelConfig:
|
||||
"must be divisible by pipeline parallel size "
|
||||
f"({pipeline_parallel_size}).")
|
||||
|
||||
def get_sliding_window(self) -> Optional[int]:
    """Return ``hf_config.sliding_window``, or None if it is not set."""
    hf_config = self.hf_config
    return getattr(hf_config, "sliding_window", None)
|
||||
|
||||
def get_vocab_size(self) -> int:
    """Return the ``vocab_size`` attribute of the HF config."""
    hf_config = self.hf_config
    return hf_config.vocab_size
|
||||
|
||||
def get_hidden_size(self) -> int:
    """Return the ``hidden_size`` attribute of the HF config."""
    hf_config = self.hf_config
    return hf_config.hidden_size
|
||||
|
||||
@@ -285,10 +332,12 @@ class ParallelConfig:
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
worker_use_ray: bool,
|
||||
max_parallel_loading_workers: Optional[int] = None,
|
||||
) -> None:
|
||||
self.pipeline_parallel_size = pipeline_parallel_size
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.max_parallel_loading_workers = max_parallel_loading_workers
|
||||
|
||||
self.world_size = pipeline_parallel_size * tensor_parallel_size
|
||||
if self.world_size > 1:
|
||||
@@ -356,6 +405,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
|
||||
|
||||
|
||||
def _get_and_verify_dtype(
|
||||
config: PretrainedConfig,
|
||||
@@ -385,6 +436,14 @@ def _get_and_verify_dtype(
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
if is_hip() and torch_dtype == torch.float32:
|
||||
rocm_supported_dtypes = [
|
||||
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
|
||||
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
|
||||
]
|
||||
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
|
||||
f"Supported dtypes are {rocm_supported_dtypes}")
|
||||
|
||||
# Verify the dtype.
|
||||
if torch_dtype != config_dtype:
|
||||
if torch_dtype == torch.float32:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""A block manager that manages token blocks."""
|
||||
import enum
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
# Mapping: logical block number -> physical block.
|
||||
BlockTable = List[PhysicalTokenBlock]
|
||||
|
||||
|
||||
class BlockAllocator:
|
||||
"""Manages free physical token blocks for a device.
|
||||
@@ -25,7 +29,7 @@ class BlockAllocator:
|
||||
self.num_blocks = num_blocks
|
||||
|
||||
# Initialize the free blocks.
|
||||
self.free_blocks: List[PhysicalTokenBlock] = []
|
||||
self.free_blocks: BlockTable = []
|
||||
for i in range(num_blocks):
|
||||
block = PhysicalTokenBlock(device=device,
|
||||
block_number=i,
|
||||
@@ -50,8 +54,18 @@ class BlockAllocator:
|
||||
return len(self.free_blocks)
|
||||
|
||||
|
||||
# Mapping: logical block number -> physical block.
|
||||
BlockTable = List[PhysicalTokenBlock]
|
||||
class AllocStatus(enum.Enum):
    """Result of BlockSpaceManager.can_allocate.

    1. OK: seq_group can be allocated now.
    2. LATER: seq_group cannot be allocated now, but the capacity of the
       allocator is larger than what seq_group requires.
    3. NEVER: seq_group can never be allocated because it is too large
       to be allocated on the GPU.
    """
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()
|
||||
|
||||
|
||||
class BlockSpaceManager:
|
||||
@@ -86,23 +100,29 @@ class BlockSpaceManager:
|
||||
# Mapping: seq_id -> BlockTable.
|
||||
self.block_tables: Dict[int, BlockTable] = {}
|
||||
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> bool:
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
|
||||
# FIXME(woosuk): Here we assume that all sequences in the group share
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
||||
num_required_blocks = len(seq.logical_token_blocks)
|
||||
if self.block_sliding_window is not None:
|
||||
num_required_blocks = min(num_required_blocks,
|
||||
self.block_sliding_window)
|
||||
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
||||
|
||||
# Use watermark to avoid frequent cache eviction.
|
||||
return (num_free_gpu_blocks - num_required_blocks >=
|
||||
self.watermark_blocks)
|
||||
if (self.num_total_gpu_blocks - num_required_blocks <
|
||||
self.watermark_blocks):
|
||||
return AllocStatus.NEVER
|
||||
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
|
||||
return AllocStatus.OK
|
||||
else:
|
||||
return AllocStatus.LATER
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
# NOTE: Here we assume that all sequences in the group have the same
|
||||
# prompt.
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
||||
|
||||
# Allocate new physical token blocks that will store the prompt tokens.
|
||||
block_table: BlockTable = []
|
||||
@@ -117,7 +137,7 @@ class BlockSpaceManager:
|
||||
block_table.append(block)
|
||||
|
||||
# Assign the block table for each sequence.
|
||||
for seq in seq_group.get_seqs():
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
|
||||
self.block_tables[seq.seq_id] = block_table.copy()
|
||||
|
||||
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
|
||||
|
||||
@@ -3,7 +3,7 @@ import time
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import CacheConfig, SchedulerConfig
|
||||
from vllm.core.block_manager import BlockSpaceManager
|
||||
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
|
||||
from vllm.core.policy import PolicyFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
@@ -139,23 +139,35 @@ class Scheduler:
|
||||
while self.waiting:
|
||||
seq_group = self.waiting[0]
|
||||
|
||||
assert seq_group.num_seqs() == 1, (
|
||||
waiting_seqs = seq_group.get_seqs(
|
||||
status=SequenceStatus.WAITING)
|
||||
assert len(waiting_seqs) == 1, (
|
||||
"Waiting sequence group should have only one prompt "
|
||||
"sequence.")
|
||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||
num_prompt_tokens = waiting_seqs[0].get_len()
|
||||
if num_prompt_tokens > self.prompt_limit:
|
||||
logger.warning(
|
||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||
f" and exceeds limit of {self.prompt_limit}")
|
||||
for seq in seq_group.get_seqs():
|
||||
for seq in waiting_seqs:
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
self.waiting.pop(0)
|
||||
continue
|
||||
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
if not self.block_manager.can_allocate(seq_group):
|
||||
can_allocate = self.block_manager.can_allocate(seq_group)
|
||||
if can_allocate == AllocStatus.LATER:
|
||||
break
|
||||
elif can_allocate == AllocStatus.NEVER:
|
||||
logger.warning(
|
||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||
f" and exceeds the capacity of block_manager")
|
||||
for seq in waiting_seqs:
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
self.waiting.pop(0)
|
||||
continue
|
||||
|
||||
# If the number of batched tokens exceeds the limit, stop.
|
||||
new_seq_lens = seq_lens + [num_prompt_tokens]
|
||||
@@ -186,7 +198,8 @@ class Scheduler:
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
scheduled_seq_groups=scheduled,
|
||||
prompt_run=True,
|
||||
num_batched_tokens=len(seq_lens) * max(seq_lens),
|
||||
num_batched_tokens=len(seq_lens) *
|
||||
max(seq_lens) if seq_lens else 0,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
@@ -306,7 +319,7 @@ class Scheduler:
|
||||
|
||||
def _allocate(self, seq_group: SequenceGroup) -> None:
|
||||
self.block_manager.allocate(seq_group)
|
||||
for seq in seq_group.get_seqs():
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
def _append_slot(
|
||||
@@ -350,7 +363,7 @@ class Scheduler:
|
||||
elif preemption_mode == PreemptionMode.SWAP:
|
||||
self._preempt_by_swap(seq_group, blocks_to_swap_out)
|
||||
else:
|
||||
assert False, "Invalid preemption mode."
|
||||
raise AssertionError("Invalid preemption mode.")
|
||||
|
||||
def _preempt_by_recompute(
|
||||
self,
|
||||
|
||||
@@ -22,6 +22,7 @@ class EngineArgs:
|
||||
worker_use_ray: bool = False
|
||||
pipeline_parallel_size: int = 1
|
||||
tensor_parallel_size: int = 1
|
||||
max_parallel_loading_workers: Optional[int] = None
|
||||
block_size: int = 16
|
||||
swap_space: int = 4 # GiB
|
||||
gpu_memory_utilization: float = 0.90
|
||||
@@ -32,6 +33,8 @@ class EngineArgs:
|
||||
revision: Optional[str] = None
|
||||
tokenizer_revision: Optional[str] = None
|
||||
quantization: Optional[str] = None
|
||||
enforce_eager: bool = False
|
||||
max_context_len_to_capture: int = 8192
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer is None:
|
||||
@@ -41,6 +44,10 @@ class EngineArgs:
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
"""Shared CLI arguments for vLLM engine."""
|
||||
|
||||
# NOTE: If you update any of the arguments below, please also
|
||||
# make sure to update docs/source/models/engine_args.rst
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
@@ -128,6 +135,12 @@ class EngineArgs:
|
||||
type=int,
|
||||
default=EngineArgs.tensor_parallel_size,
|
||||
help='number of tensor parallel replicas')
|
||||
parser.add_argument(
|
||||
'--max-parallel-loading-workers',
|
||||
type=int,
|
||||
help='load model sequentially in multiple batches, '
|
||||
'to avoid RAM OOM when using tensor '
|
||||
'parallel and large models')
|
||||
# KV cache arguments
|
||||
parser.add_argument('--block-size',
|
||||
type=int,
|
||||
@@ -143,11 +156,13 @@ class EngineArgs:
|
||||
type=int,
|
||||
default=EngineArgs.swap_space,
|
||||
help='CPU swap space size (GiB) per GPU')
|
||||
parser.add_argument('--gpu-memory-utilization',
|
||||
type=float,
|
||||
default=EngineArgs.gpu_memory_utilization,
|
||||
help='the percentage of GPU memory to be used for'
|
||||
'the model executor')
|
||||
parser.add_argument(
|
||||
'--gpu-memory-utilization',
|
||||
type=float,
|
||||
default=EngineArgs.gpu_memory_utilization,
|
||||
help='the fraction of GPU memory to be used for '
|
||||
'the model executor, which can range from 0 to 1.'
|
||||
'If unspecified, will use the default value of 0.9.')
|
||||
parser.add_argument('--max-num-batched-tokens',
|
||||
type=int,
|
||||
default=EngineArgs.max_num_batched_tokens,
|
||||
@@ -168,9 +183,25 @@ class EngineArgs:
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
type=str,
|
||||
choices=['awq', 'squeezellm', None],
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None,
|
||||
help='Method used to quantize the weights')
|
||||
help='Method used to quantize the weights. If '
|
||||
'None, we first check the `quantization_config` '
|
||||
'attribute in the model config file. If that is '
|
||||
'None, we assume the model weights are not '
|
||||
'quantized and use `dtype` to determine the data '
|
||||
'type of the weights.')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='Always use eager-mode PyTorch. If False, '
|
||||
'will use eager mode and CUDA graph in hybrid '
|
||||
'for maximal performance and flexibility.')
|
||||
parser.add_argument('--max-context-len-to-capture',
|
||||
type=int,
|
||||
default=EngineArgs.max_context_len_to_capture,
|
||||
help='maximum context length covered by CUDA '
|
||||
'graphs. When a sequence has context length '
|
||||
'larger than this, we fall back to eager mode.')
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@@ -189,13 +220,16 @@ class EngineArgs:
|
||||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed, self.revision,
|
||||
self.tokenizer_revision, self.max_model_len,
|
||||
self.quantization)
|
||||
cache_config = CacheConfig(
|
||||
self.block_size, self.gpu_memory_utilization, self.swap_space,
|
||||
getattr(model_config.hf_config, 'sliding_window', None))
|
||||
self.quantization, self.enforce_eager,
|
||||
self.max_context_len_to_capture)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space,
|
||||
model_config.get_sliding_window())
|
||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
self.worker_use_ray)
|
||||
self.worker_use_ray,
|
||||
self.max_parallel_loading_workers)
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs,
|
||||
model_config.max_model_len,
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
|
||||
Union)
|
||||
Union, AsyncIterator)
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@@ -183,49 +183,53 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return ignored
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
|
||||
# Execute the model.
|
||||
output = await self._run_workers_async(
|
||||
"execute_model",
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
if not scheduler_outputs.is_empty():
|
||||
# Execute the model.
|
||||
all_outputs = await self._run_workers_async(
|
||||
"execute_model",
|
||||
driver_kwargs={
|
||||
"seq_group_metadata_list": seq_group_metadata_list,
|
||||
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
|
||||
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
|
||||
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
|
||||
})
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs) + ignored
|
||||
# Only the driver worker returns the sampling results.
|
||||
output = all_outputs[0]
|
||||
else:
|
||||
output = []
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
|
||||
async def _run_workers_async(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
get_all_outputs: bool = False,
|
||||
driver_args: Optional[List[Any]] = None,
|
||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
coros = []
|
||||
|
||||
if driver_args is None:
|
||||
driver_args = args
|
||||
if driver_kwargs is None:
|
||||
driver_kwargs = kwargs
|
||||
|
||||
# Run the driver worker asynchronously.
|
||||
driver_executor = getattr(self.driver_worker, method)
|
||||
coros.append(asyncio.get_event_loop().run_in_executor(
|
||||
None, partial(driver_executor, *driver_args, **driver_kwargs)))
|
||||
|
||||
# Run the ray workers asynchronously.
|
||||
for worker in self.workers:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
coros.append(
|
||||
worker.execute_method.remote(method, *args, **kwargs))
|
||||
else:
|
||||
executor = getattr(worker, method)
|
||||
coros.append(asyncio.get_event_loop().run_in_executor(
|
||||
None, partial(executor, *args, **kwargs)))
|
||||
coros.append(worker.execute_method.remote(method, *args, **kwargs))
|
||||
|
||||
all_outputs = await asyncio.gather(*coros)
|
||||
|
||||
if get_all_outputs:
|
||||
return all_outputs
|
||||
|
||||
# Make sure all workers have the same results.
|
||||
output = all_outputs[0]
|
||||
for other_output in all_outputs[1:]:
|
||||
assert output == other_output
|
||||
return output
|
||||
return all_outputs
|
||||
|
||||
|
||||
class AsyncLLMEngine:
|
||||
@@ -301,7 +305,16 @@ class AsyncLLMEngine:
|
||||
elif self.worker_use_ray:
|
||||
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
|
||||
else:
|
||||
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
||||
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
|
||||
# order of the arguments.
|
||||
cache_config = args[1]
|
||||
parallel_config = args[2]
|
||||
if parallel_config.tensor_parallel_size == 1:
|
||||
num_gpus = cache_config.gpu_memory_utilization
|
||||
else:
|
||||
num_gpus = 1
|
||||
engine_class = ray.remote(num_gpus=num_gpus)(
|
||||
self._engine_class).remote
|
||||
return engine_class(*args, **kwargs)
|
||||
|
||||
async def engine_step(self) -> bool:
|
||||
@@ -392,11 +405,12 @@ class AsyncLLMEngine:
|
||||
return stream
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
) -> AsyncIterator[RequestOutput]:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
@@ -480,13 +494,12 @@ class AsyncLLMEngine:
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
parallel_config = engine_configs[2]
|
||||
# Initialize the cluster.
|
||||
distributed_init_method, placement_group = initialize_cluster(
|
||||
parallel_config, engine_args.engine_use_ray)
|
||||
placement_group = initialize_cluster(parallel_config,
|
||||
engine_args.engine_use_ray)
|
||||
# Create the async LLM engine.
|
||||
engine = cls(parallel_config.worker_use_ray,
|
||||
engine_args.engine_use_ray,
|
||||
*engine_configs,
|
||||
distributed_init_method,
|
||||
placement_group,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
|
||||
@@ -1,25 +1,26 @@
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
|
||||
from vllm.engine.metrics import record_metrics
|
||||
from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceGroupOutputs,
|
||||
SequenceOutputs, SequenceStatus)
|
||||
SequenceGroupOutput, SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter
|
||||
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
|
||||
|
||||
if ray:
|
||||
from ray.air.util.torch_dist import init_torch_dist_process_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -52,8 +53,6 @@ class LLMEngine:
|
||||
management.
|
||||
parallel_config: The configuration related to distributed execution.
|
||||
scheduler_config: The configuration related to the request scheduler.
|
||||
distributed_init_method: The initialization method for distributed
|
||||
execution. See `torch.distributed.init_process_group` for details.
|
||||
placement_group: Ray placement group for distributed execution.
|
||||
Required for distributed execution.
|
||||
log_stats: Whether to log statistics.
|
||||
@@ -65,7 +64,6 @@ class LLMEngine:
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
distributed_init_method: str,
|
||||
placement_group: Optional["PlacementGroup"],
|
||||
log_stats: bool,
|
||||
) -> None:
|
||||
@@ -83,13 +81,12 @@ class LLMEngine:
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"enforce_eager={model_config.enforce_eager}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
assert self.cache_config.sliding_window == getattr(
|
||||
self.model_config.hf_config, "sliding_window", None)
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.log_stats = log_stats
|
||||
@@ -105,9 +102,13 @@ class LLMEngine:
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
if self.parallel_config.worker_use_ray:
|
||||
# Disable Ray usage stats collection.
|
||||
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||
if ray_usage != "1":
|
||||
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||
self._init_workers_ray(placement_group)
|
||||
else:
|
||||
self._init_workers(distributed_init_method)
|
||||
self._init_workers()
|
||||
|
||||
# Profile the memory usage and initialize the cache.
|
||||
self._init_cache()
|
||||
@@ -122,65 +123,133 @@ class LLMEngine:
|
||||
# List of (timestamp, num_tokens)
|
||||
self.num_generation_tokens: List[Tuple[float, int]] = []
|
||||
|
||||
def _init_workers(self, distributed_init_method: str):
|
||||
def _init_workers(self):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
assert self.parallel_config.world_size == 1, (
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
worker = Worker(
|
||||
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
|
||||
self.driver_worker = Worker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
0,
|
||||
distributed_init_method,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
self._run_workers(
|
||||
"init_model",
|
||||
get_all_outputs=True,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
self._run_workers("init_model")
|
||||
self._run_workers("load_model")
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||
if self.parallel_config.tensor_parallel_size == 1:
|
||||
num_gpus = self.cache_config.gpu_memory_utilization
|
||||
else:
|
||||
num_gpus = 1
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
for bundle in placement_group.bundle_specs:
|
||||
self.driver_dummy_worker: RayWorkerVllm = None
|
||||
self.workers: List[RayWorkerVllm] = []
|
||||
|
||||
driver_ip = get_ip()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=1,
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True),
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorker).remote(self.model_config.trust_remote_code)
|
||||
self.workers.append(worker)
|
||||
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
|
||||
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
else:
|
||||
self.workers.append(worker)
|
||||
|
||||
if self.driver_dummy_worker is None:
|
||||
raise ValueError(
|
||||
"Ray does not allocate any GPUs on the driver node. Consider "
|
||||
"adjusting the Ray placement group or running the driver on a "
|
||||
"GPU node.")
|
||||
|
||||
driver_node_id, driver_gpu_ids = ray.get(
|
||||
self.driver_dummy_worker.get_node_and_gpu_ids.remote())
|
||||
worker_node_and_gpu_ids = ray.get(
|
||||
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
|
||||
|
||||
node_workers = defaultdict(list)
|
||||
node_gpus = defaultdict(list)
|
||||
|
||||
node_workers[driver_node_id].append(0)
|
||||
node_gpus[driver_node_id].extend(driver_gpu_ids)
|
||||
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
|
||||
start=1):
|
||||
node_workers[node_id].append(i)
|
||||
node_gpus[node_id].extend(gpu_ids)
|
||||
for node_id, gpu_ids in node_gpus.items():
|
||||
node_gpus[node_id] = sorted(gpu_ids)
|
||||
|
||||
# Set CUDA_VISIBLE_DEVICES for the driver.
|
||||
set_cuda_visible_devices(node_gpus[driver_node_id])
|
||||
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
|
||||
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
|
||||
|
||||
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
|
||||
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
# Initialize torch distributed process group for the workers.
|
||||
init_torch_dist_process_group(self.workers, backend="nccl")
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||
self._run_workers("init_worker",
|
||||
get_all_outputs=True,
|
||||
worker_init_fn=lambda: Worker(
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
|
||||
for rank, (worker, (node_id,
|
||||
_)) in enumerate(zip(self.workers,
|
||||
worker_node_and_gpu_ids),
|
||||
start=1):
|
||||
local_rank = node_workers[node_id].index(rank)
|
||||
worker.init_worker.remote(
|
||||
lambda rank=rank, local_rank=local_rank: Worker(
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
local_rank,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
))
|
||||
|
||||
driver_rank = 0
|
||||
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
||||
self.driver_worker = Worker(
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
driver_local_rank,
|
||||
driver_rank,
|
||||
distributed_init_method,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
self._run_workers("init_model")
|
||||
self._run_workers(
|
||||
"init_model",
|
||||
get_all_outputs=True,
|
||||
"load_model",
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers,
|
||||
)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
@@ -192,7 +261,6 @@ class LLMEngine:
|
||||
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||
num_blocks = self._run_workers(
|
||||
"profile_num_available_blocks",
|
||||
get_all_outputs=True,
|
||||
block_size=self.cache_config.block_size,
|
||||
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
|
||||
cpu_swap_space=self.cache_config.swap_space_bytes,
|
||||
@@ -211,12 +279,23 @@ class LLMEngine:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine.")
|
||||
max_seq_len = self.cache_config.block_size * num_gpu_blocks
|
||||
if self.model_config.max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({self.model_config.max_model_len}) "
|
||||
"is larger than the maximum number of tokens that can be "
|
||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
||||
"initializing the engine.")
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
# Initialize the cache.
|
||||
self._run_workers("init_cache_engine", cache_config=self.cache_config)
|
||||
# Warm up the model. This includes capturing the model into CUDA graph
|
||||
# if enforce_eager is False.
|
||||
self._run_workers("warm_up_model")
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
|
||||
@@ -225,11 +304,9 @@ class LLMEngine:
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
parallel_config = engine_configs[2]
|
||||
# Initialize the cluster.
|
||||
distributed_init_method, placement_group = initialize_cluster(
|
||||
parallel_config)
|
||||
placement_group = initialize_cluster(parallel_config)
|
||||
# Create the LLM engine.
|
||||
engine = cls(*engine_configs,
|
||||
distributed_init_method,
|
||||
placement_group,
|
||||
log_stats=not engine_args.disable_log_stats)
|
||||
return engine
|
||||
@@ -296,16 +373,6 @@ class LLMEngine:
|
||||
"""Returns True if there are unfinished requests."""
|
||||
return self.scheduler.has_unfinished_seqs()
|
||||
|
||||
def _schedule(
|
||||
self
|
||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||
List[RequestOutput]]:
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
return seq_group_metadata_list, scheduler_outputs, [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
|
||||
def _check_beam_search_early_stopping(
|
||||
self,
|
||||
early_stopping: Union[bool, str],
|
||||
@@ -351,7 +418,7 @@ class LLMEngine:
|
||||
return current_worst_score >= highest_attainable_score
|
||||
|
||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||
outputs: SequenceGroupOutputs) -> None:
|
||||
outputs: SequenceGroupOutput) -> None:
|
||||
# Process prompt logprobs
|
||||
prompt_logprobs = outputs.prompt_logprobs
|
||||
if prompt_logprobs is not None:
|
||||
@@ -372,7 +439,7 @@ class LLMEngine:
|
||||
|
||||
# Process the child samples for each parent sequence
|
||||
for parent in parent_seqs:
|
||||
child_samples: List[SequenceOutputs] = parent_child_dict[
|
||||
child_samples: List[SequenceOutput] = parent_child_dict[
|
||||
parent.seq_id]
|
||||
if len(child_samples) == 0:
|
||||
# This parent sequence has no children samples. Remove
|
||||
@@ -554,18 +621,23 @@ class LLMEngine:
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return ignored
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
|
||||
# Execute the model.
|
||||
output = self._run_workers(
|
||||
"execute_model",
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
if not scheduler_outputs.is_empty():
|
||||
# Execute the model.
|
||||
all_outputs = self._run_workers(
|
||||
"execute_model",
|
||||
driver_kwargs={
|
||||
"seq_group_metadata_list": seq_group_metadata_list,
|
||||
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
|
||||
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
|
||||
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
|
||||
})
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
output = all_outputs[0]
|
||||
else:
|
||||
output = []
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
|
||||
@@ -581,8 +653,8 @@ class LLMEngine:
|
||||
else:
|
||||
self.num_generation_tokens.append((now, num_batched_tokens))
|
||||
|
||||
elapsed_time = now - self.last_logging_time
|
||||
if elapsed_time < _LOGGING_INTERVAL_SEC:
|
||||
should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC
|
||||
if not should_log:
|
||||
return
|
||||
|
||||
# Discard the old stats.
|
||||
@@ -621,6 +693,16 @@ class LLMEngine:
|
||||
else:
|
||||
cpu_cache_usage = 0.0
|
||||
|
||||
record_metrics(
|
||||
avg_prompt_throughput=avg_prompt_throughput,
|
||||
avg_generation_throughput=avg_generation_throughput,
|
||||
scheduler_running=len(self.scheduler.running),
|
||||
scheduler_swapped=len(self.scheduler.swapped),
|
||||
scheduler_waiting=len(self.scheduler.waiting),
|
||||
gpu_cache_usage=gpu_cache_usage,
|
||||
cpu_cache_usage=cpu_cache_usage,
|
||||
)
|
||||
|
||||
logger.info("Avg prompt throughput: "
|
||||
f"{avg_prompt_throughput:.1f} tokens/s, "
|
||||
"Avg generation throughput: "
|
||||
@@ -657,9 +739,10 @@ class LLMEngine:
|
||||
"""Stop the finished sequences."""
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
if not sampling_params.include_stop_str_in_output:
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
||||
@@ -686,28 +769,34 @@ class LLMEngine:
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
get_all_outputs: bool = False,
|
||||
driver_args: Optional[List[Any]] = None,
|
||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
all_outputs = []
|
||||
for worker in self.workers:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
executor = partial(worker.execute_method.remote, method)
|
||||
else:
|
||||
executor = getattr(worker, method)
|
||||
|
||||
output = executor(*args, **kwargs)
|
||||
all_outputs.append(output)
|
||||
if max_concurrent_workers:
|
||||
raise NotImplementedError(
|
||||
"max_concurrent_workers is not supported yet.")
|
||||
|
||||
if self.parallel_config.worker_use_ray:
|
||||
all_outputs = ray.get(all_outputs)
|
||||
# Start the ray workers first.
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *args, **kwargs)
|
||||
for worker in self.workers
|
||||
]
|
||||
|
||||
if get_all_outputs:
|
||||
return all_outputs
|
||||
if driver_args is None:
|
||||
driver_args = args
|
||||
if driver_kwargs is None:
|
||||
driver_kwargs = kwargs
|
||||
|
||||
# Make sure all workers have the same results.
|
||||
output = all_outputs[0]
|
||||
for other_output in all_outputs[1:]:
|
||||
assert output == other_output
|
||||
return output
|
||||
# Start the driver worker after all the ray workers.
|
||||
driver_worker_output = getattr(self.driver_worker,
|
||||
method)(*driver_args, **driver_kwargs)
|
||||
|
||||
# Get the results of the ray workers.
|
||||
if self.workers:
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
|
||||
return [driver_worker_output] + ray_worker_outputs
|
||||
|
||||
51
vllm/engine/metrics.py
Normal file
51
vllm/engine/metrics.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from aioprometheus import Gauge
|
||||
|
||||
# The begin-* and end-* markers here are used by the documentation generator
|
||||
# to extract the metrics definitions.
|
||||
|
||||
# begin-metrics-definitions
|
||||
gauge_avg_prompt_throughput = Gauge("vllm:avg_prompt_throughput_toks_per_s",
|
||||
"Average prefill throughput in tokens/s.")
|
||||
gauge_avg_generation_throughput = Gauge(
|
||||
"vllm:avg_generation_throughput_toks_per_s",
|
||||
"Average generation throughput in tokens/s.")
|
||||
|
||||
gauge_scheduler_running = Gauge(
|
||||
"vllm:num_requests_running",
|
||||
"Number of requests that is currently running for inference.")
|
||||
gauge_scheduler_swapped = Gauge("vllm:num_requests_swapped",
|
||||
"Number requests swapped to CPU.")
|
||||
gauge_scheduler_waiting = Gauge("vllm:num_requests_waiting",
|
||||
"Number of requests waiting to be processed.")
|
||||
|
||||
gauge_gpu_cache_usage = Gauge(
|
||||
"vllm:gpu_cache_usage_perc",
|
||||
"GPU KV-cache usage. 1 means 100 percent usage.")
|
||||
gauge_cpu_cache_usage = Gauge(
|
||||
"vllm:cpu_cache_usage_perc",
|
||||
"CPU KV-cache usage. 1 means 100 percent usage.")
|
||||
# end-metrics-definitions
|
||||
|
||||
labels = {}
|
||||
|
||||
|
||||
def add_global_metrics_labels(**kwargs):
|
||||
labels.update(kwargs)
|
||||
|
||||
|
||||
def record_metrics(
|
||||
avg_prompt_throughput: float,
|
||||
avg_generation_throughput: float,
|
||||
scheduler_running: int,
|
||||
scheduler_swapped: int,
|
||||
scheduler_waiting: int,
|
||||
gpu_cache_usage: float,
|
||||
cpu_cache_usage: float,
|
||||
):
|
||||
gauge_avg_prompt_throughput.set(labels, avg_prompt_throughput)
|
||||
gauge_avg_generation_throughput.set(labels, avg_generation_throughput)
|
||||
gauge_scheduler_running.set(labels, scheduler_running)
|
||||
gauge_scheduler_swapped.set(labels, scheduler_swapped)
|
||||
gauge_scheduler_waiting.set(labels, scheduler_waiting)
|
||||
gauge_gpu_cache_usage.set(labels, gpu_cache_usage)
|
||||
gauge_cpu_cache_usage.set(labels, cpu_cache_usage)
|
||||
@@ -1,22 +1,20 @@
|
||||
import socket
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip, set_cuda_visible_devices, get_ip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import ray
|
||||
from ray.air.util.torch_dist import TorchDistributedWorker
|
||||
|
||||
class RayWorker(TorchDistributedWorker):
|
||||
class RayWorkerVllm:
|
||||
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
|
||||
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
|
||||
|
||||
def __init__(self, init_cached_hf_modules=False) -> None:
|
||||
if init_cached_hf_modules:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from transformers.dynamic_module_utils import init_hf_modules
|
||||
init_hf_modules()
|
||||
self.worker = None
|
||||
@@ -31,24 +29,28 @@ try:
|
||||
executor = getattr(self, method)
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
def get_node_ip(self) -> str:
|
||||
return get_ip()
|
||||
|
||||
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
|
||||
node_id = ray.get_runtime_context().get_node_id()
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
return node_id, gpu_ids
|
||||
|
||||
def set_cuda_visible_devices(self, device_ids) -> None:
|
||||
set_cuda_visible_devices(device_ids)
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import Ray with {e!r}. "
|
||||
"For distributed inference, please install Ray with "
|
||||
"`pip install ray pandas pyarrow`.")
|
||||
ray = None
|
||||
TorchDistributedWorker = None
|
||||
RayWorker = None # pylint: disable=invalid-name
|
||||
RayWorkerVllm = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
|
||||
def get_open_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def initialize_cluster(
|
||||
parallel_config: ParallelConfig,
|
||||
engine_use_ray: bool = False,
|
||||
@@ -74,16 +76,19 @@ def initialize_cluster(
|
||||
"Ray is not installed. Please install Ray to use distributed "
|
||||
"serving.")
|
||||
# Connect to a ray cluster.
|
||||
ray.init(address=ray_address, ignore_reinit_error=True)
|
||||
if is_hip():
|
||||
ray.init(address=ray_address,
|
||||
ignore_reinit_error=True,
|
||||
num_gpus=parallel_config.world_size)
|
||||
else:
|
||||
ray.init(address=ray_address, ignore_reinit_error=True)
|
||||
|
||||
if not parallel_config.worker_use_ray:
|
||||
# Initialize cluster locally.
|
||||
port = get_open_port()
|
||||
# We need to setup the distributed init method to make sure
|
||||
# the distributed megatron code (e.g., get world size) works correctly.
|
||||
distributed_init_method = f"tcp://localhost:{port}"
|
||||
return distributed_init_method, None
|
||||
assert parallel_config.world_size == 1, (
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
return None
|
||||
|
||||
# Create placement group for worker processes
|
||||
current_placement_group = ray.util.get_current_placement_group()
|
||||
if current_placement_group:
|
||||
# We are in a placement group
|
||||
@@ -108,12 +113,12 @@ def initialize_cluster(
|
||||
"The number of required GPUs exceeds the total number of "
|
||||
"available GPUs in the cluster.")
|
||||
# Create a new placement group
|
||||
current_placement_group = ray.util.placement_group([{
|
||||
"GPU": 1
|
||||
}] * parallel_config.world_size)
|
||||
placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
|
||||
current_placement_group = ray.util.placement_group(
|
||||
placement_group_specs)
|
||||
# Wait until PG is ready - this will block until all
|
||||
# requested resources are available, and will timeout
|
||||
# if they cannot be provisioned.
|
||||
ray.get(current_placement_group.ready(), timeout=1800)
|
||||
|
||||
return None, current_placement_group
|
||||
return current_placement_group
|
||||
|
||||
@@ -12,7 +12,6 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
|
||||
@@ -73,6 +72,8 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--ssl-keyfile", type=str, default=None)
|
||||
parser.add_argument("--ssl-certfile", type=str, default=None)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -83,4 +84,6 @@ if __name__ == "__main__":
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug",
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile)
|
||||
|
||||
@@ -38,8 +38,10 @@ class LLM:
|
||||
However, if the `torch_dtype` in the config is `float32`, we will
|
||||
use `float16` instead.
|
||||
quantization: The method used to quantize the model weights. Currently,
|
||||
we support "awq". If None, we assume the model weights are not
|
||||
quantized and use `dtype` to determine the data type of the weights.
|
||||
we support "awq", "gptq" and "squeezellm". If None, we first check
|
||||
the `quantization_config` attribute in the model config file. If
|
||||
that is None, we assume the model weights are not quantized and use
|
||||
`dtype` to determine the data type of the weights.
|
||||
revision: The specific model version to use. It can be a branch name,
|
||||
a tag name, or a commit id.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
@@ -55,6 +57,12 @@ class LLM:
|
||||
when their `best_of` sampling parameters are larger than 1. If all
|
||||
requests will have `best_of=1`, you can safely set this to 0.
|
||||
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
If False, we will use CUDA graph and eager execution in hybrid.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -71,6 +79,8 @@ class LLM:
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: int = 4,
|
||||
enforce_eager: bool = False,
|
||||
max_context_len_to_capture: int = 8192,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
@@ -88,6 +98,8 @@ class LLM:
|
||||
seed=seed,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
enforce_eager=enforce_eager,
|
||||
max_context_len_to_capture=max_context_len_to_capture,
|
||||
**kwargs,
|
||||
)
|
||||
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
||||
@@ -134,25 +146,21 @@ class LLM:
|
||||
if isinstance(prompts, str):
|
||||
# Convert a single prompt to a list.
|
||||
prompts = [prompts]
|
||||
if prompts is not None and prompt_token_ids is not None:
|
||||
if len(prompts) != len(prompt_token_ids):
|
||||
raise ValueError("The lengths of prompts and prompt_token_ids "
|
||||
"must be the same.")
|
||||
if (prompts is not None and prompt_token_ids is not None
|
||||
and len(prompts) != len(prompt_token_ids)):
|
||||
raise ValueError("The lengths of prompts and prompt_token_ids "
|
||||
"must be the same.")
|
||||
if sampling_params is None:
|
||||
# Use default sampling params.
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
# Add requests to the engine.
|
||||
if prompts is not None:
|
||||
num_requests = len(prompts)
|
||||
else:
|
||||
num_requests = len(prompt_token_ids)
|
||||
num_requests = len(prompts) if prompts is not None else len(
|
||||
prompt_token_ids)
|
||||
for i in range(num_requests):
|
||||
prompt = prompts[i] if prompts is not None else None
|
||||
if prompt_token_ids is None:
|
||||
token_ids = None
|
||||
else:
|
||||
token_ids = prompt_token_ids[i]
|
||||
token_ids = None if prompt_token_ids is None else prompt_token_ids[
|
||||
i]
|
||||
self._add_request(prompt, sampling_params, token_ids)
|
||||
return self._run_engine(use_tqdm)
|
||||
|
||||
|
||||
@@ -3,21 +3,24 @@
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from aioprometheus import MetricsMiddleware
|
||||
from aioprometheus.asgi.starlette import metrics
|
||||
import fastapi
|
||||
import uvicorn
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
||||
from packaging import version
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.metrics import add_global_metrics_labels
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
CompletionRequest, CompletionResponse, CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice, CompletionStreamResponse,
|
||||
@@ -31,20 +34,67 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
try:
|
||||
import fastchat
|
||||
from fastchat.conversation import Conversation, SeparatorStyle
|
||||
from fastchat.model.model_adapter import get_conversation_template
|
||||
_fastchat_available = True
|
||||
except ImportError:
|
||||
_fastchat_available = False
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
logger = init_logger(__name__)
|
||||
served_model = None
|
||||
app = fastapi.FastAPI()
|
||||
engine = None
|
||||
response_role = None
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser.add_argument("--host", type=str, default=None, help="host name")
|
||||
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
help="allow credentials")
|
||||
parser.add_argument("--allowed-origins",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed origins")
|
||||
parser.add_argument("--allowed-methods",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed methods")
|
||||
parser.add_argument("--allowed-headers",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed headers")
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. If not "
|
||||
"specified, the model name will be the same as "
|
||||
"the huggingface name.")
|
||||
parser.add_argument("--chat-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to the chat template, "
|
||||
"or the template in single-line form "
|
||||
"for the specified model")
|
||||
parser.add_argument("--response-role",
|
||||
type=str,
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=true`.")
|
||||
parser.add_argument("--ssl-keyfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to the SSL key file")
|
||||
parser.add_argument("--ssl-certfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to the SSL cert file")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics
|
||||
app.add_route("/metrics", metrics) # Exposes HTTP metrics
|
||||
|
||||
|
||||
def create_error_response(status_code: HTTPStatus,
|
||||
@@ -54,8 +104,27 @@ def create_error_response(status_code: HTTPStatus,
|
||||
status_code=status_code.value)
|
||||
|
||||
|
||||
def load_chat_template(args, tokenizer):
|
||||
if args.chat_template is not None:
|
||||
try:
|
||||
with open(args.chat_template, "r") as f:
|
||||
chat_template = f.read()
|
||||
except OSError:
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
chat_template = codecs.decode(args.chat_template, "unicode_escape")
|
||||
|
||||
tokenizer.chat_template = chat_template
|
||||
logger.info(
|
||||
f"Using supplied chat template:\n{tokenizer.chat_template}")
|
||||
elif tokenizer.chat_template is not None:
|
||||
logger.info(f"Using default chat template:\n{tokenizer.chat_template}")
|
||||
else:
|
||||
logger.warning("No chat template provided. Chat API will not work.")
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
|
||||
async def validation_exception_handler(_, exc):
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
|
||||
|
||||
|
||||
@@ -69,53 +138,6 @@ async def check_model(request) -> Optional[JSONResponse]:
|
||||
return ret
|
||||
|
||||
|
||||
async def get_gen_prompt(request) -> str:
|
||||
if not _fastchat_available:
|
||||
raise ModuleNotFoundError(
|
||||
"fastchat is not installed. Please install fastchat to use "
|
||||
"the chat completion and conversation APIs: `$ pip install fschat`"
|
||||
)
|
||||
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
|
||||
raise ImportError(
|
||||
f"fastchat version is low. Current version: {fastchat.__version__} "
|
||||
"Please upgrade fastchat to use: `$ pip install -U fschat`")
|
||||
|
||||
conv = get_conversation_template(request.model)
|
||||
conv = Conversation(
|
||||
name=conv.name,
|
||||
system_template=conv.system_template,
|
||||
system_message=conv.system_message,
|
||||
roles=conv.roles,
|
||||
messages=list(conv.messages), # prevent in-place modification
|
||||
offset=conv.offset,
|
||||
sep_style=SeparatorStyle(conv.sep_style),
|
||||
sep=conv.sep,
|
||||
sep2=conv.sep2,
|
||||
stop_str=conv.stop_str,
|
||||
stop_token_ids=conv.stop_token_ids,
|
||||
)
|
||||
|
||||
if isinstance(request.messages, str):
|
||||
prompt = request.messages
|
||||
else:
|
||||
for message in request.messages:
|
||||
msg_role = message["role"]
|
||||
if msg_role == "system":
|
||||
conv.system_message = message["content"]
|
||||
elif msg_role == "user":
|
||||
conv.append_message(conv.roles[0], message["content"])
|
||||
elif msg_role == "assistant":
|
||||
conv.append_message(conv.roles[1], message["content"])
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {msg_role}")
|
||||
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def check_length(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
@@ -124,10 +146,8 @@ async def check_length(
|
||||
assert (not (prompt is None and prompt_ids is None)
|
||||
and not (prompt is not None and prompt_ids is not None)
|
||||
), "Either prompt or prompt_ids should be provided."
|
||||
if prompt_ids is not None:
|
||||
input_ids = prompt_ids
|
||||
else:
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
input_ids = prompt_ids if prompt_ids is not None else tokenizer(
|
||||
prompt).input_ids
|
||||
token_num = len(input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
@@ -162,16 +182,26 @@ async def show_available_models():
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
|
||||
def create_logprobs(token_ids: List[int],
|
||||
id_logprobs: List[Dict[int, float]],
|
||||
initial_text_offset: int = 0) -> LogProbs:
|
||||
def create_logprobs(
|
||||
token_ids: List[int],
|
||||
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
|
||||
num_output_top_logprobs: Optional[int] = None,
|
||||
initial_text_offset: int = 0,
|
||||
) -> LogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
for token_id, id_logprob in zip(token_ids, id_logprobs):
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is not None:
|
||||
token_logprob = step_top_logprobs[token_id]
|
||||
else:
|
||||
token_logprob = None
|
||||
token = tokenizer.convert_ids_to_tokens(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(id_logprob[token_id])
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
@@ -179,10 +209,11 @@ def create_logprobs(token_ids: List[int],
|
||||
last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
logprobs.top_logprobs.append({
|
||||
tokenizer.convert_ids_to_tokens(i): p
|
||||
for i, p in id_logprob.items()
|
||||
})
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append({
|
||||
tokenizer.convert_ids_to_tokens(i): p
|
||||
for i, p in step_top_logprobs.items()
|
||||
} if step_top_logprobs else None)
|
||||
return logprobs
|
||||
|
||||
|
||||
@@ -198,8 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
- function_call (Users should implement this by themselves)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
logger.info(f"Received chat completion request: {request}")
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@@ -209,7 +238,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
prompt = await get_gen_prompt(request)
|
||||
try:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
conversation=request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.add_generation_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in applying chat template from request: {str(e)}")
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@@ -217,14 +254,17 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.monotonic())
|
||||
chunk_object_type = "chat.completion.chunk"
|
||||
try:
|
||||
spaces_between_special_tokens = request.spaces_between_special_tokens
|
||||
sampling_params = SamplingParams(
|
||||
n=request.n,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
min_p=request.min_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
max_tokens=request.max_tokens,
|
||||
@@ -241,128 +281,161 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||
token_ids)
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
finish_reason: Optional[str] = None,
|
||||
usage: Optional[UsageInfo] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=text),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
response = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[choice_data],
|
||||
)
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
# exclude unset to leave details out of each sse
|
||||
response_json = response.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
return response_json
|
||||
def get_role() -> str:
|
||||
if request.add_generation_prompt:
|
||||
return response_role
|
||||
else:
|
||||
return request.messages[-1]["role"]
|
||||
|
||||
async def completion_stream_generator() -> AsyncGenerator[str, None]:
|
||||
# First chunk with role
|
||||
# Send first response for each request.n (index) with the role
|
||||
role = get_role()
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
index=i, delta=DeltaMessage(role=role), finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the last message
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
if last_msg_content:
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=last_msg_content),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response for each token for each request.n (index)
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
finish_reason_sent = [False] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
completion_tokens = len(output.token_ids)
|
||||
previous_num_tokens[i] = completion_tokens
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
if output.finish_reason is not None:
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Send token-by-token response for each request.n
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
# Send the finish response for each request.n only once
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
completion_tokens=previous_num_tokens[i],
|
||||
total_tokens=prompt_tokens + previous_num_tokens[i],
|
||||
)
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text="",
|
||||
finish_reason=output.finish_reason,
|
||||
usage=final_usage,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=[], finish_reason=output.finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
if final_usage is not None:
|
||||
chunk.usage = final_usage
|
||||
data = chunk.json(exclude_unset=True,
|
||||
exclude_none=True,
|
||||
ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
finish_reason_sent[i] = True
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def completion_full_generator():
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
|
||||
choices = []
|
||||
role = get_role()
|
||||
for output in final_res.outputs:
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role=role, content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
for choice in choices:
|
||||
full_message = last_msg_content + choice.message.content
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
choices = []
|
||||
for output in final_res.outputs:
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role="assistant", content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(fake_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
return response
|
||||
else:
|
||||
return await completion_full_generator()
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
@@ -373,23 +446,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following features:
|
||||
- echo (since the vLLM engine does not currently support
|
||||
getting the logprobs of prompt tokens)
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
logger.info(f"Received completion request: {request}")
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
if request.echo:
|
||||
# We do not support echo since the vLLM engine does not
|
||||
# currently support getting the logprobs of prompt tokens.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"echo is not currently supported")
|
||||
# OpenAI API supports echoing the prompt when max_tokens is 0.
|
||||
echo_without_generation = request.echo and request.max_tokens == 0
|
||||
|
||||
if request.suffix is not None:
|
||||
# The language models we currently support do not support suffix.
|
||||
@@ -439,15 +506,19 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
best_of=request.best_of,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
min_p=request.min_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
ignore_eos=request.ignore_eos,
|
||||
max_tokens=request.max_tokens,
|
||||
max_tokens=request.max_tokens
|
||||
if not echo_without_generation else 1,
|
||||
logprobs=request.logprobs,
|
||||
use_beam_search=request.use_beam_search,
|
||||
prompt_logprobs=request.logprobs if request.echo else None,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
)
|
||||
@@ -497,24 +568,47 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
async def completion_stream_generator() -> AsyncGenerator[str, None]:
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
has_echoed = [False] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
token_ids = output.token_ids[previous_num_tokens[i]:]
|
||||
if request.logprobs is not None:
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i]:]
|
||||
else:
|
||||
top_logprobs = None
|
||||
offsets = len(previous_texts[i])
|
||||
if request.echo and not has_echoed[i]:
|
||||
if not echo_without_generation:
|
||||
delta_text = res.prompt + delta_text
|
||||
token_ids = res.prompt_token_ids + token_ids
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs + top_logprobs
|
||||
else: # only just return the prompt
|
||||
delta_text = res.prompt
|
||||
token_ids = res.prompt_token_ids
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs
|
||||
has_echoed[i] = True
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(
|
||||
output.token_ids[previous_num_tokens[i]:],
|
||||
output.logprobs[previous_num_tokens[i]:],
|
||||
len(previous_texts[i]))
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
initial_text_offset=offsets,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
if output.finish_reason is not None:
|
||||
@@ -553,14 +647,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
choices = []
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
prompt_logprobs = final_res.prompt_logprobs
|
||||
prompt_text = final_res.prompt
|
||||
for output in final_res.outputs:
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(output.token_ids, output.logprobs)
|
||||
if not echo_without_generation:
|
||||
token_ids = output.token_ids
|
||||
top_logprobs = output.logprobs
|
||||
if request.echo:
|
||||
token_ids = prompt_token_ids + token_ids
|
||||
top_logprobs = prompt_logprobs + top_logprobs
|
||||
else:
|
||||
token_ids = prompt_token_ids
|
||||
top_logprobs = prompt_logprobs
|
||||
logprobs = create_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
if not echo_without_generation:
|
||||
output_text = output.text
|
||||
if request.echo:
|
||||
output_text = prompt_text + output_text
|
||||
else:
|
||||
output_text = prompt_text
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=output.index,
|
||||
text=output.text,
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
@@ -598,34 +714,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser.add_argument("--host", type=str, default=None, help="host name")
|
||||
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
help="allow credentials")
|
||||
parser.add_argument("--allowed-origins",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed origins")
|
||||
parser.add_argument("--allowed-methods",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed methods")
|
||||
parser.add_argument("--allowed-headers",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed headers")
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. If not "
|
||||
"specified, the model name will be the same as "
|
||||
"the huggingface name.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -642,6 +731,8 @@ if __name__ == "__main__":
|
||||
else:
|
||||
served_model = args.model
|
||||
|
||||
response_role = args.response_role
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
@@ -652,9 +743,15 @@ if __name__ == "__main__":
|
||||
engine_model_config.tokenizer,
|
||||
tokenizer_mode=engine_model_config.tokenizer_mode,
|
||||
trust_remote_code=engine_model_config.trust_remote_code)
|
||||
load_chat_template(args, tokenizer)
|
||||
|
||||
# Register labels for metrics
|
||||
add_global_metrics_labels(model_name=engine_args.model)
|
||||
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="info",
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile)
|
||||
|
||||
@@ -73,6 +73,10 @@ class ChatCompletionRequest(BaseModel):
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
add_generation_prompt: Optional[bool] = True
|
||||
echo: Optional[bool] = False
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
min_p: Optional[float] = 0.0
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
@@ -100,14 +104,15 @@ class CompletionRequest(BaseModel):
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
min_p: Optional[float] = 0.0
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
tokens: List[str] = Field(default_factory=list)
|
||||
top_logprobs: List[Optional[Dict[str,
|
||||
float]]] = Field(default_factory=list)
|
||||
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
__all__ = [
|
||||
"InputMetadata",
|
||||
"get_model",
|
||||
"SamplingMetadata",
|
||||
"set_random_seed",
|
||||
]
|
||||
|
||||
@@ -1,91 +1,44 @@
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from xformers.ops import AttentionBias
|
||||
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
class InputMetadata:
|
||||
"""Metadata for input sequences. Used for PagedAttention.
|
||||
"""Metadata for input sequences. Used in PagedAttention.
|
||||
|
||||
Args:
|
||||
seq_groups: List of (seq_ids, sampling_params).
|
||||
seq_data: Seq_id -> SequenceData.
|
||||
prompt_lens: Lengths of prompts.
|
||||
slot_mapping: The address to write the new KV to of each token.
|
||||
context_lens: the length of attention context for each generation token.
|
||||
max_context_len: The maximum context length.
|
||||
context_lens: the length of attention context for each sequence.
|
||||
block_tables: The block tables. (Seq id -> list of physical block)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
seq_data: Dict[int, SequenceData],
|
||||
prompt_lens: List[int],
|
||||
is_prompt: bool,
|
||||
slot_mapping: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
block_tables: torch.Tensor,
|
||||
selected_token_indices: torch.Tensor,
|
||||
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
|
||||
sliding_window: Optional[int] = None,
|
||||
max_context_len: Optional[int],
|
||||
context_lens: Optional[torch.Tensor],
|
||||
block_tables: Optional[torch.Tensor],
|
||||
use_cuda_graph: bool,
|
||||
) -> None:
|
||||
self.seq_groups = seq_groups
|
||||
self.seq_data = seq_data
|
||||
self.prompt_lens = prompt_lens
|
||||
self.is_prompt = is_prompt
|
||||
self.max_context_len = max_context_len
|
||||
self.slot_mapping = slot_mapping
|
||||
self.context_lens = context_lens
|
||||
self.max_context_len = max_context_len
|
||||
self.block_tables = block_tables
|
||||
self.selected_token_indices = selected_token_indices
|
||||
self.categorized_sample_indices = categorized_sample_indices
|
||||
|
||||
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
|
||||
self.to_cache = None
|
||||
if sliding_window is not None:
|
||||
# We need to keep the positions of sliding windows within
|
||||
# the key / value tables, this is helpful to know which
|
||||
# elements we need to cache.
|
||||
to_cache, start_idx = [], 0
|
||||
for prompt_len in self.prompt_lens:
|
||||
to_cache.extend(
|
||||
range(
|
||||
start_idx + max(0, prompt_len - sliding_window),
|
||||
start_idx + prompt_len,
|
||||
))
|
||||
start_idx += self.max_prompt_len
|
||||
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
|
||||
self.to_cache = torch.tensor(to_cache,
|
||||
dtype=torch.int32,
|
||||
device=self.slot_mapping.device)
|
||||
|
||||
self.num_prompts = len(prompt_lens)
|
||||
self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
|
||||
self.num_generation_tokens = context_lens.shape[0]
|
||||
if block_tables.numel() > 0:
|
||||
self.max_num_blocks_per_seq = block_tables.shape[1]
|
||||
else:
|
||||
self.max_num_blocks_per_seq = 0
|
||||
assert block_tables.shape[0] == self.num_generation_tokens
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
|
||||
# Set during the execution of the first attention op.
|
||||
self.attn_bias: Optional[AttentionBias] = None
|
||||
# FIXME(woosuk): This is a hack.
|
||||
self.attn_bias = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Print only useful metadata.
|
||||
return (
|
||||
f'InputMetadata('
|
||||
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
||||
f'num_prompts={self.num_prompts}, '
|
||||
f'prompt_lens={self.prompt_lens}, '
|
||||
f'num_generation_tokens={self.num_generation_tokens}, '
|
||||
f'context_lens={self.context_lens}, '
|
||||
f'max_context_len={self.max_context_len}), '
|
||||
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
||||
f'block_tables={self.block_tables}, '
|
||||
f'selected_token_indices={self.selected_token_indices}, '
|
||||
f'categorized_sample_indices={self.categorized_sample_indices}, '
|
||||
f'slot_mapping={self.slot_mapping})')
|
||||
return ("InputMetadata("
|
||||
f"is_prompt={self.is_prompt}, "
|
||||
f"max_context_len={self.max_context_len}, "
|
||||
f"slot_mapping={self.slot_mapping}, "
|
||||
f"context_lens={self.context_lens}, "
|
||||
f"block_tables={self.block_tables}, "
|
||||
f"use_cuda_graph={self.use_cuda_graph})")
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
"""Custom activation functions."""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import activation_ops
|
||||
from vllm._C import ops
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.utils import divide
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
@@ -18,27 +24,43 @@ class SiluAndMul(nn.Module):
|
||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||
"""
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
ops.silu_and_mul(out, x)
|
||||
return out
|
||||
|
||||
|
||||
class NewGELU(nn.Module):
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
c = math.sqrt(2.0 / math.pi)
|
||||
return 0.5 * x * (1.0 + torch.tanh(c *
|
||||
(x + 0.044715 * torch.pow(x, 3.0))))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
activation_ops.gelu_new(out, x)
|
||||
ops.gelu_new(out, x)
|
||||
return out
|
||||
|
||||
|
||||
class FastGELU(nn.Module):
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
|
||||
(1.0 + 0.044715 * x * x)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
activation_ops.gelu_fast(out, x)
|
||||
ops.gelu_fast(out, x)
|
||||
return out
|
||||
|
||||
|
||||
@@ -51,17 +73,40 @@ class ScaledActivation(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
act_module: nn.Module,
|
||||
hidden_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
intermediate_size: int,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = act_module
|
||||
self.input_is_parallel = input_is_parallel
|
||||
if input_is_parallel:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size_per_partition = divide(intermediate_size,
|
||||
tp_size)
|
||||
else:
|
||||
intermediate_size_per_partition = intermediate_size
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.scales = nn.Parameter(
|
||||
torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
|
||||
torch.empty(intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
device="cuda"))
|
||||
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.act(x) / self.scales
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param_data = param.data
|
||||
if self.input_is_parallel:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = param_data.shape[0]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = {
|
||||
"gelu": nn.GELU(),
|
||||
@@ -76,6 +121,8 @@ def get_act_fn(
|
||||
act_fn_name: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
@@ -84,15 +131,11 @@ def get_act_fn(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
|
||||
if quant_config is not None:
|
||||
if act_fn_name in quant_config.get_scaled_act_names():
|
||||
if intermediate_size is None:
|
||||
raise ValueError(
|
||||
"intermediate_size must be specified for scaled "
|
||||
"activation functions.")
|
||||
return ScaledActivation(
|
||||
act_fn,
|
||||
intermediate_size,
|
||||
params_dtype=torch.get_default_dtype(),
|
||||
)
|
||||
if (quant_config is not None
|
||||
and act_fn_name in quant_config.get_scaled_act_names()):
|
||||
if intermediate_size is None:
|
||||
raise ValueError("intermediate_size must be specified for scaled "
|
||||
"activation functions.")
|
||||
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
|
||||
params_dtype)
|
||||
return act_fn
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user