Compare commits
441 Commits
v0.5.4
...
v0.6.1.pos
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
acda0b35d0 | ||
|
|
ba77527955 | ||
|
|
6821020109 | ||
|
|
8427550488 | ||
|
|
3f79bc3d1a | ||
|
|
40c396533d | ||
|
|
5ec9c0fb3c | ||
|
|
8f44a92d85 | ||
|
|
360ddbd37e | ||
|
|
a480939e8e | ||
|
|
d31174a4e1 | ||
|
|
b61bd98f90 | ||
|
|
c16369455f | ||
|
|
019877253b | ||
|
|
551ce01078 | ||
|
|
a6c0f3658d | ||
|
|
f2e263b801 | ||
|
|
1f0c75afa9 | ||
|
|
8a23e93302 | ||
|
|
c6202daeed | ||
|
|
e56bf27741 | ||
|
|
520ca380ae | ||
|
|
7de49aa86c | ||
|
|
42ffba11ad | ||
|
|
295c4730a8 | ||
|
|
1bf2dd9df0 | ||
|
|
5a60699c45 | ||
|
|
b6c75e1cf2 | ||
|
|
b71c956deb | ||
|
|
f842a7aff1 | ||
|
|
a65cb16067 | ||
|
|
3fd2b0d21c | ||
|
|
d394787e52 | ||
|
|
775f00f81e | ||
|
|
8baa454937 | ||
|
|
73202dbe77 | ||
|
|
7015417fd4 | ||
|
|
aea02f30de | ||
|
|
0b952af458 | ||
|
|
3b7fea770f | ||
|
|
cea95dfb94 | ||
|
|
6a512a00df | ||
|
|
efcf946a15 | ||
|
|
1230263e16 | ||
|
|
e497b8aeff | ||
|
|
94144e726c | ||
|
|
1d5e397aa4 | ||
|
|
22f3a4bc6c | ||
|
|
b1f3e18958 | ||
|
|
04e7c4e771 | ||
|
|
5faedf1b62 | ||
|
|
02751a7a42 | ||
|
|
f421f3cefb | ||
|
|
8c054b7a62 | ||
|
|
6234385f4a | ||
|
|
da1a844e61 | ||
|
|
a1d874224d | ||
|
|
6cd5e5b07e | ||
|
|
c7cb5c3335 | ||
|
|
f9b4a2d415 | ||
|
|
58fcc8545a | ||
|
|
08287ef675 | ||
|
|
4ef41b8476 | ||
|
|
cfe712bf1a | ||
|
|
b962ee1470 | ||
|
|
36bf8150cc | ||
|
|
e807125936 | ||
|
|
9f68e00d27 | ||
|
|
ce2702a923 | ||
|
|
795b662cff | ||
|
|
2f707fcb35 | ||
|
|
41e95c5247 | ||
|
|
12dd715807 | ||
|
|
29f49cd6e3 | ||
|
|
23f322297f | ||
|
|
9db52eab3d | ||
|
|
1447c97e75 | ||
|
|
de80783b69 | ||
|
|
e5cab71531 | ||
|
|
baa5467547 | ||
|
|
db3bf7c991 | ||
|
|
2febcf2777 | ||
|
|
2ee45281a5 | ||
|
|
9da25a88aa | ||
|
|
8685ba1a1e | ||
|
|
288a938872 | ||
|
|
e39ebf5cf5 | ||
|
|
ba262c4e5a | ||
|
|
4624d98dbd | ||
|
|
1afc931987 | ||
|
|
e01c2beb7d | ||
|
|
32e7db2536 | ||
|
|
008cf886c9 | ||
|
|
77d9e514a2 | ||
|
|
e02ce498be | ||
|
|
561d6f8077 | ||
|
|
d1dec64243 | ||
|
|
2ad2e5608e | ||
|
|
d3311562fb | ||
|
|
ccd7207191 | ||
|
|
855c262a6b | ||
|
|
2be8ec6e71 | ||
|
|
e16fa99a6a | ||
|
|
61f4a93d14 | ||
|
|
d4db9f53c8 | ||
|
|
2188a60c7e | ||
|
|
dc0b6066ab | ||
|
|
0af3abe3d3 | ||
|
|
f1575dc99f | ||
|
|
c02638efb3 | ||
|
|
652c83b697 | ||
|
|
6d646d08a2 | ||
|
|
95a178f861 | ||
|
|
bd852f2a8b | ||
|
|
ec266536b7 | ||
|
|
0fbc6696c2 | ||
|
|
6e36f4fa6c | ||
|
|
dd2a6a82e3 | ||
|
|
4ca65a9763 | ||
|
|
e2b2aa5a0f | ||
|
|
e6a26ed037 | ||
|
|
f8d60145b4 | ||
|
|
5b86b19954 | ||
|
|
5231f0898e | ||
|
|
8423aef4c8 | ||
|
|
4f5d8446ed | ||
|
|
d05f0a9db2 | ||
|
|
622f8abff8 | ||
|
|
1248e8506a | ||
|
|
2684efc467 | ||
|
|
058344f89a | ||
|
|
98cef6a227 | ||
|
|
f97be32d1d | ||
|
|
afd39a4511 | ||
|
|
2148441fd3 | ||
|
|
dc13e99348 | ||
|
|
34a0e96d46 | ||
|
|
80c7b089b1 | ||
|
|
428dd1445e | ||
|
|
4abed65c58 | ||
|
|
0c785d344d | ||
|
|
4664ceaad6 | ||
|
|
257afc37c5 | ||
|
|
86a677de42 | ||
|
|
d78789ac16 | ||
|
|
c334b1898b | ||
|
|
6b3421567d | ||
|
|
3f60f2244e | ||
|
|
f205c09854 | ||
|
|
ef99a78760 | ||
|
|
74d5543ec5 | ||
|
|
a7f65c2be9 | ||
|
|
4289cad37f | ||
|
|
af59df0a10 | ||
|
|
ce6bf3a2cf | ||
|
|
3cdfe1f38b | ||
|
|
fdd9daafa3 | ||
|
|
8c56e57def | ||
|
|
eeffde1ac0 | ||
|
|
e5697d161c | ||
|
|
b98cc28f91 | ||
|
|
ef9baee3c5 | ||
|
|
98c12cffe5 | ||
|
|
f52a43a8b9 | ||
|
|
e3580537a4 | ||
|
|
f508e03e7f | ||
|
|
51f86bf487 | ||
|
|
c166e7e43e | ||
|
|
bc6e42a9b1 | ||
|
|
fab5f53e2d | ||
|
|
9c71c97ae2 | ||
|
|
5340a2dccf | ||
|
|
345be0e244 | ||
|
|
fc911880cc | ||
|
|
ed6f002d33 | ||
|
|
b09c755be8 | ||
|
|
42e932c7d4 | ||
|
|
076169f603 | ||
|
|
9db642138b | ||
|
|
6fc4e6e07a | ||
|
|
9606c7197d | ||
|
|
64cc644425 | ||
|
|
39178c7fbc | ||
|
|
2eedede875 | ||
|
|
015e6cc252 | ||
|
|
760e9f71a8 | ||
|
|
05826c887b | ||
|
|
dd9857f5fa | ||
|
|
665304092d | ||
|
|
2deb029d11 | ||
|
|
029c71de11 | ||
|
|
0b769992ec | ||
|
|
1856aff4d6 | ||
|
|
70c094ade6 | ||
|
|
2059b8d9ca | ||
|
|
8aaf3d5347 | ||
|
|
80162c44b1 | ||
|
|
aab0fcdb63 | ||
|
|
ea9fa160e3 | ||
|
|
7d9ffa2ae1 | ||
|
|
d81abefd2e | ||
|
|
8da48e4d95 | ||
|
|
6885fde317 | ||
|
|
9db93de20c | ||
|
|
09c7792610 | ||
|
|
f1df5dbfd6 | ||
|
|
35ee2ad6b9 | ||
|
|
e25fee57c2 | ||
|
|
faeddb565d | ||
|
|
fc5ebbd1d3 | ||
|
|
c01a6cb231 | ||
|
|
b903e1ba7f | ||
|
|
a152246428 | ||
|
|
666ad0aa16 | ||
|
|
15310b5101 | ||
|
|
57792ed469 | ||
|
|
d3b5b98021 | ||
|
|
cc0eaf12b1 | ||
|
|
955b5191c9 | ||
|
|
55d63b1211 | ||
|
|
4f419c00a6 | ||
|
|
a3fce56b88 | ||
|
|
b3856bef7d | ||
|
|
8c6f694a79 | ||
|
|
eeee1c3b1a | ||
|
|
aae74ef95c | ||
|
|
cde9183b40 | ||
|
|
df1a21131d | ||
|
|
7937009a7e | ||
|
|
9984605412 | ||
|
|
7eebe8ccaa | ||
|
|
8678a69ab5 | ||
|
|
5844017285 | ||
|
|
1ca0d4f86b | ||
|
|
dd53c4b023 | ||
|
|
970dfdc01d | ||
|
|
91f4522cbf | ||
|
|
1b32e02648 | ||
|
|
f7e3b0c5aa | ||
|
|
d3c002eadc | ||
|
|
9b73a2f498 | ||
|
|
6925cdbeea | ||
|
|
53328d7536 | ||
|
|
c75363fbc0 | ||
|
|
dd3fa0e430 | ||
|
|
baaedfdb2d | ||
|
|
4506641212 | ||
|
|
12e1c65bc9 | ||
|
|
b74a125800 | ||
|
|
66a9e713a7 | ||
|
|
9e51b6a626 | ||
|
|
6e4658c7aa | ||
|
|
3b682179dd | ||
|
|
c6af027a35 | ||
|
|
2aa00d59ad | ||
|
|
c42590f97a | ||
|
|
aae6927be0 | ||
|
|
398521ad19 | ||
|
|
5288c06aa0 | ||
|
|
b6f99a6ffe | ||
|
|
ad28a74beb | ||
|
|
e6d811dd13 | ||
|
|
c4be16e1a7 | ||
|
|
3d8a5f063d | ||
|
|
f4fc7337bf | ||
|
|
0df7ec0b2d | ||
|
|
312f761232 | ||
|
|
e54ebc2f8f | ||
|
|
67e02fa8a4 | ||
|
|
43735bf5e1 | ||
|
|
da115230fd | ||
|
|
7601cb044d | ||
|
|
47b65a5508 | ||
|
|
dad961ef5c | ||
|
|
3ac50b47d0 | ||
|
|
df845b2b46 | ||
|
|
1a36287b89 | ||
|
|
f710fb5265 | ||
|
|
ff7ec82c4d | ||
|
|
200a2ffa6b | ||
|
|
40e1360bb6 | ||
|
|
e3b318216d | ||
|
|
ab7165f2c7 | ||
|
|
0c2fa50b84 | ||
|
|
ce143353c6 | ||
|
|
bbf55c4805 | ||
|
|
1ef13cf92f | ||
|
|
832163b875 | ||
|
|
e73f76eec6 | ||
|
|
d95cc0a55c | ||
|
|
5bf45db7df | ||
|
|
eed020f673 | ||
|
|
7c0b7ea214 | ||
|
|
4706eb628e | ||
|
|
bae888cb8e | ||
|
|
6bd19551b0 | ||
|
|
e680349994 | ||
|
|
44f26a9466 | ||
|
|
37fd47e780 | ||
|
|
7759ae958f | ||
|
|
9f69856356 | ||
|
|
d4f0f17b02 | ||
|
|
b3f4e17935 | ||
|
|
93478b63d2 | ||
|
|
f366f6339b | ||
|
|
855866caa9 | ||
|
|
7fc23be81c | ||
|
|
e837b624f2 | ||
|
|
ec724a725e | ||
|
|
0e39a33c6d | ||
|
|
6fc5b0f249 | ||
|
|
9587b050fb | ||
|
|
54bd9a03c4 | ||
|
|
50b8d08dbd | ||
|
|
e165528778 | ||
|
|
3b19e39dc5 | ||
|
|
4cd7d47fed | ||
|
|
f878c8feb0 | ||
|
|
b67ae00cdb | ||
|
|
9c8e2d1161 | ||
|
|
21313e09e3 | ||
|
|
f4da5f7b6d | ||
|
|
9c1f78d5d6 | ||
|
|
fc93e56143 | ||
|
|
22b39e11f2 | ||
|
|
f55a9aea45 | ||
|
|
951fdd66d3 | ||
|
|
2ecf7b1757 | ||
|
|
3f674a49b5 | ||
|
|
70b746efcf | ||
|
|
67d115db08 | ||
|
|
d3d9cb6e4b | ||
|
|
c134a46402 | ||
|
|
199adbb7cf | ||
|
|
dd164d72f3 | ||
|
|
ea49e6a3c8 | ||
|
|
97992802f3 | ||
|
|
59edd0f134 | ||
|
|
a08df8322e | ||
|
|
16422ea76f | ||
|
|
373538f973 | ||
|
|
33e5d7e6b6 | ||
|
|
c5c7768264 | ||
|
|
b1e5afc3e7 | ||
|
|
d3bdfd3ab9 | ||
|
|
fb377d7e74 | ||
|
|
181abbc27d | ||
|
|
00c3d68e45 | ||
|
|
e20233d361 | ||
|
|
d6e634f3d7 | ||
|
|
4d2dc5072b | ||
|
|
7025b11d94 | ||
|
|
5469146bcc | ||
|
|
97a6be95ba | ||
|
|
9ba85bc152 | ||
|
|
198d6a2898 | ||
|
|
774cd1d3bf | ||
|
|
91294d56e1 | ||
|
|
a046f86397 | ||
|
|
4ddc4743d7 | ||
|
|
6aa33cb2dd | ||
|
|
1137f343aa | ||
|
|
9b3e2edd30 | ||
|
|
65950e8f58 | ||
|
|
cfba4def5d | ||
|
|
d2bc4510a4 | ||
|
|
24154f8618 | ||
|
|
e6e42e4b17 | ||
|
|
ec2affa8ae | ||
|
|
86ab567bae | ||
|
|
f020a6297e | ||
|
|
6c8e595710 | ||
|
|
02b1988b9f | ||
|
|
386087970a | ||
|
|
c08e2b3086 | ||
|
|
4fb7b52a2c | ||
|
|
90bab18f24 | ||
|
|
4c5d8e8ea9 | ||
|
|
baa240252e | ||
|
|
999ef0b917 | ||
|
|
5c6c54d67a | ||
|
|
933790c209 | ||
|
|
70d268a399 | ||
|
|
249b88228d | ||
|
|
74af2bbd90 | ||
|
|
fc7b8d1eef | ||
|
|
67abdbb42f | ||
|
|
07ab160741 | ||
|
|
b4e9528f95 | ||
|
|
57b7be0e1c | ||
|
|
99b4cf5f23 | ||
|
|
e02ac55617 | ||
|
|
73388c07a4 | ||
|
|
7eb4a51c5f | ||
|
|
0fa14907da | ||
|
|
5923532e15 | ||
|
|
a049b107e2 | ||
|
|
8334c39f37 | ||
|
|
e904576743 | ||
|
|
e14fb22e59 | ||
|
|
782e53ab59 | ||
|
|
21b9c49aa3 | ||
|
|
5fb4a3f678 | ||
|
|
757ac70a64 | ||
|
|
6dffa4b0a6 | ||
|
|
48abee9e54 | ||
|
|
746709642c | ||
|
|
e53dfd3eaf | ||
|
|
6d94420246 | ||
|
|
fc1493a01e | ||
|
|
311f743831 | ||
|
|
469b3bc538 | ||
|
|
5223199e03 | ||
|
|
fde47d3bc2 | ||
|
|
0e12cd67a8 | ||
|
|
80cbe10c59 | ||
|
|
b764547616 | ||
|
|
ab0f5e2823 | ||
|
|
564985729a | ||
|
|
0f7052bc7e | ||
|
|
639159b2a6 | ||
|
|
66d617e343 | ||
|
|
7b261092de | ||
|
|
2385c8f374 | ||
|
|
9a3f49ae07 | ||
|
|
f9a5600649 | ||
|
|
fd95e026e0 | ||
|
|
660470e5a3 | ||
|
|
8d59dbb000 | ||
|
|
5c60c8c423 | ||
|
|
00afc78590 | ||
|
|
541c1852d3 | ||
|
|
a3bbbfa1d8 | ||
|
|
1f26efbb3a | ||
|
|
9118217f58 | ||
|
|
e3c664bfcb | ||
|
|
360bd67cf0 | ||
|
|
ef527be06c | ||
|
|
89b8db6bb2 | ||
|
|
789937af2e | ||
|
|
dfb1a15dcb |
@@ -1,36 +1,43 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
MAX_SIZE_MB = 250
|
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
|
||||||
|
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))
|
||||||
|
|
||||||
|
|
||||||
def print_top_10_largest_files(zip_file):
|
def print_top_10_largest_files(zip_file):
|
||||||
|
"""Print the top 10 largest files in the given zip file."""
|
||||||
with zipfile.ZipFile(zip_file, 'r') as z:
|
with zipfile.ZipFile(zip_file, 'r') as z:
|
||||||
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
|
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
|
||||||
file_sizes.sort(key=lambda x: x[1], reverse=True)
|
file_sizes.sort(key=lambda x: x[1], reverse=True)
|
||||||
for f, size in file_sizes[:10]:
|
for f, size in file_sizes[:10]:
|
||||||
print(f"{f}: {size/(1024*1024)} MBs uncompressed.")
|
print(f"{f}: {size / (1024 * 1024):.2f} MBs uncompressed.")
|
||||||
|
|
||||||
|
|
||||||
def check_wheel_size(directory):
|
def check_wheel_size(directory):
|
||||||
|
"""Check the size of .whl files in the given directory."""
|
||||||
for root, _, files in os.walk(directory):
|
for root, _, files in os.walk(directory):
|
||||||
for f in files:
|
for file_name in files:
|
||||||
if f.endswith(".whl"):
|
if file_name.endswith(".whl"):
|
||||||
wheel_path = os.path.join(root, f)
|
wheel_path = os.path.join(root, file_name)
|
||||||
wheel_size = os.path.getsize(wheel_path)
|
wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
|
||||||
wheel_size_mb = wheel_size / (1024 * 1024)
|
if wheel_size_mb > VLLM_MAX_SIZE_MB:
|
||||||
if wheel_size_mb > MAX_SIZE_MB:
|
print(f"Not allowed: Wheel {wheel_path} is larger "
|
||||||
print(
|
f"({wheel_size_mb:.2f} MB) than the limit "
|
||||||
f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
|
f"({VLLM_MAX_SIZE_MB} MB).")
|
||||||
f"compare to the allowed size ({MAX_SIZE_MB} MB).")
|
|
||||||
print_top_10_largest_files(wheel_path)
|
print_top_10_largest_files(wheel_path)
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
print(f"Wheel {wheel_path} is within the allowed size "
|
print(f"Wheel {wheel_path} is within the allowed size "
|
||||||
f"({wheel_size_mb} MB).")
|
f"({wheel_size_mb:.2f} MB).")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
if len(sys.argv) < 2:
|
||||||
sys.exit(check_wheel_size(sys.argv[1]))
|
print("Usage: python check-wheel-size.py <directory>")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
directory = sys.argv[1]
|
||||||
|
sys.exit(check_wheel_size(directory))
|
||||||
@@ -9,3 +9,4 @@ tasks:
|
|||||||
value: 0.664
|
value: 0.664
|
||||||
limit: 1000
|
limit: 1000
|
||||||
num_fewshot: 5
|
num_fewshot: 5
|
||||||
|
trust_remote_code: True
|
||||||
@@ -4,8 +4,8 @@ tasks:
|
|||||||
- name: "gsm8k"
|
- name: "gsm8k"
|
||||||
metrics:
|
metrics:
|
||||||
- name: "exact_match,strict-match"
|
- name: "exact_match,strict-match"
|
||||||
value: 0.409
|
value: 0.419
|
||||||
- name: "exact_match,flexible-extract"
|
- name: "exact_match,flexible-extract"
|
||||||
value: 0.406
|
value: 0.416
|
||||||
limit: 1000
|
limit: 1000
|
||||||
num_fewshot: 5
|
num_fewshot: 5
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
|
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
|
||||||
model_name: "nvidia/Minitron-4B-Base"
|
model_name: "mgoin/Minitron-4B-Base-FP8"
|
||||||
tasks:
|
tasks:
|
||||||
- name: "gsm8k"
|
- name: "gsm8k"
|
||||||
metrics:
|
metrics:
|
||||||
- name: "exact_match,strict-match"
|
- name: "exact_match,strict-match"
|
||||||
value: 0.252
|
value: 0.233
|
||||||
- name: "exact_match,flexible-extract"
|
- name: "exact_match,flexible-extract"
|
||||||
value: 0.252
|
value: 0.236
|
||||||
limit: 1000
|
limit: 1000
|
||||||
num_fewshot: 5
|
num_fewshot: 5
|
||||||
@@ -1,10 +1,9 @@
|
|||||||
Meta-Llama-3-8B-Instruct.yaml
|
Meta-Llama-3-8B-Instruct.yaml
|
||||||
Meta-Llama-3-8B-Instruct-FP8.yaml
|
|
||||||
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
|
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
|
||||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
||||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||||
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
|
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
|
||||||
Minitron-4B-Base.yaml
|
Minitron-4B-Base-FP8.yaml
|
||||||
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
||||||
Qwen2-1.5B-Instruct-FP8W8.yaml
|
Qwen2-1.5B-Instruct-FP8W8.yaml
|
||||||
Meta-Llama-3-8B-QQQ.yaml
|
Meta-Llama-3-8B-QQQ.yaml
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import lm_eval
|
|||||||
import numpy
|
import numpy
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
RTOL = 0.02
|
RTOL = 0.05
|
||||||
TEST_DATA_FILE = os.environ.get(
|
TEST_DATA_FILE = os.environ.get(
|
||||||
"LM_EVAL_TEST_DATA_FILE",
|
"LM_EVAL_TEST_DATA_FILE",
|
||||||
".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")
|
".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")
|
||||||
@@ -23,9 +23,12 @@ TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)
|
|||||||
|
|
||||||
|
|
||||||
def launch_lm_eval(eval_config):
|
def launch_lm_eval(eval_config):
|
||||||
|
trust_remote_code = eval_config.get('trust_remote_code', False)
|
||||||
|
|
||||||
model_args = f"pretrained={eval_config['model_name']}," \
|
model_args = f"pretrained={eval_config['model_name']}," \
|
||||||
f"tensor_parallel_size={TP_SIZE}," \
|
f"tensor_parallel_size={TP_SIZE}," \
|
||||||
f"add_bos_token=true"
|
f"add_bos_token=true," \
|
||||||
|
f"trust_remote_code={trust_remote_code}"
|
||||||
|
|
||||||
results = lm_eval.simple_evaluate(
|
results = lm_eval.simple_evaluate(
|
||||||
model="vllm",
|
model="vllm",
|
||||||
|
|||||||
@@ -34,17 +34,18 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performan
|
|||||||
|
|
||||||
Performance benchmark will be triggered when:
|
Performance benchmark will be triggered when:
|
||||||
- A PR being merged into vllm.
|
- A PR being merged into vllm.
|
||||||
- Every commit for those PRs with `perf-benchmarks` label.
|
- Every commit for those PRs with `perf-benchmarks` label AND `ready` label.
|
||||||
|
|
||||||
Nightly benchmark will be triggered when:
|
Nightly benchmark will be triggered when:
|
||||||
- Every commit for those PRs with `nightly-benchmarks` label.
|
- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Performance benchmark details
|
## Performance benchmark details
|
||||||
|
|
||||||
See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
|
|
||||||
|
See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
|
||||||
|
|
||||||
|
|
||||||
#### Latency test
|
#### Latency test
|
||||||
@@ -68,7 +69,7 @@ Here is an example of one test inside `latency-tests.json`:
|
|||||||
|
|
||||||
In this example:
|
In this example:
|
||||||
- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`.
|
- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`.
|
||||||
- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-benchmarks-suite.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
|
- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
|
||||||
|
|
||||||
Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly.
|
Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly.
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ steps:
|
|||||||
containers:
|
containers:
|
||||||
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
|
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
|
||||||
command:
|
command:
|
||||||
- bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh
|
- bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
|
||||||
resources:
|
resources:
|
||||||
limits:
|
limits:
|
||||||
nvidia.com/gpu: 8
|
nvidia.com/gpu: 8
|
||||||
|
|||||||
@@ -1,47 +1,42 @@
|
|||||||
|
|
||||||
## Latency tests
|
## Latency tests
|
||||||
|
|
||||||
This test suite aims to test vllm's end-to-end latency under a controlled setup.
|
|
||||||
|
|
||||||
- Input length: 32 tokens.
|
- Input length: 32 tokens.
|
||||||
- Output length: 128 tokens.
|
- Output length: 128 tokens.
|
||||||
- Batch size: fixed (8).
|
- Batch size: fixed (8).
|
||||||
- Models: llama-3 8B, llama-3 70B, mixtral 8x7B.
|
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||||
- Evaluation metrics: end-to-end latency (mean, median, p99).
|
- Evaluation metrics: end-to-end latency (mean, median, p99).
|
||||||
|
|
||||||
### Latency benchmarking results
|
|
||||||
|
|
||||||
{latency_tests_markdown_table}
|
{latency_tests_markdown_table}
|
||||||
|
|
||||||
## Throughput tests
|
|
||||||
|
|
||||||
This test suite aims to test vllm's throughput.
|
## Throughput tests
|
||||||
|
|
||||||
- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
|
- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
|
||||||
- Output length: the corresponding output length of these 200 prompts.
|
- Output length: the corresponding output length of these 200 prompts.
|
||||||
- Batch size: dynamically determined by vllm to achieve maximum throughput.
|
- Batch size: dynamically determined by vllm to achieve maximum throughput.
|
||||||
- Models: llama-3 8B, llama-3 70B, mixtral 8x7B.
|
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||||
- Evaluation metrics: throughput.
|
- Evaluation metrics: throughput.
|
||||||
|
|
||||||
### Throughput benchmarking results
|
|
||||||
|
|
||||||
{throughput_tests_markdown_table}
|
{throughput_tests_markdown_table}
|
||||||
|
|
||||||
## Serving tests
|
|
||||||
|
|
||||||
This test suite aims to test vllm's real serving metrics.
|
## Serving tests
|
||||||
|
|
||||||
- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
|
- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
|
||||||
- Output length: the corresponding output length of these 200 prompts.
|
- Output length: the corresponding output length of these 200 prompts.
|
||||||
- Batch size: dynamically determined by vllm and the arrival pattern of the requests.
|
- Batch size: dynamically determined by vllm and the arrival pattern of the requests.
|
||||||
- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed).
|
- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed).
|
||||||
- Models: llama-3 8B, llama-3 70B, mixtral 8x7B.
|
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||||
|
- We also added a speculative decoding test for llama-3 70B, under QPS 2
|
||||||
- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).
|
- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).
|
||||||
|
|
||||||
### Serving benchmarking results
|
|
||||||
|
|
||||||
{serving_tests_markdown_table}
|
{serving_tests_markdown_table}
|
||||||
|
|
||||||
|
|
||||||
## json version of the benchmarking tables
|
## json version of the benchmarking tables
|
||||||
|
|
||||||
This section contains the data of the markdown tables above in JSON format.
|
This section contains the data of the markdown tables above in JSON format.
|
||||||
@@ -174,8 +174,8 @@ if __name__ == "__main__":
|
|||||||
# document the result
|
# document the result
|
||||||
with open(results_folder / "benchmark_results.md", "w") as f:
|
with open(results_folder / "benchmark_results.md", "w") as f:
|
||||||
|
|
||||||
results = read_markdown(
|
results = read_markdown("../.buildkite/nightly-benchmarks/" +
|
||||||
"../.buildkite/nightly-benchmarks/tests/descriptions.md")
|
"performance-benchmarks-descriptions.md")
|
||||||
results = results.format(
|
results = results.format(
|
||||||
latency_tests_markdown_table=latency_md_table,
|
latency_tests_markdown_table=latency_md_table,
|
||||||
throughput_tests_markdown_table=throughput_md_table,
|
throughput_tests_markdown_table=throughput_md_table,
|
||||||
|
|||||||
@@ -37,9 +37,9 @@ check_hf_token() {
|
|||||||
ensure_sharegpt_downloaded() {
|
ensure_sharegpt_downloaded() {
|
||||||
local FILE=ShareGPT_V3_unfiltered_cleaned_split.json
|
local FILE=ShareGPT_V3_unfiltered_cleaned_split.json
|
||||||
if [ ! -f "$FILE" ]; then
|
if [ ! -f "$FILE" ]; then
|
||||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE
|
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE
|
||||||
else
|
else
|
||||||
echo "$FILE already exists."
|
echo "$FILE already exists."
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,35 +68,38 @@ wait_for_server() {
|
|||||||
done' && return 0 || return 1
|
done' && return 0 || return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
kill_gpu_processes() {
|
kill_processes_launched_by_current_bash() {
|
||||||
# kill all processes on GPU.
|
# Kill all python processes launched from current bash script
|
||||||
pids=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)
|
current_shell_pid=$$
|
||||||
if [ -z "$pids" ]; then
|
processes=$(ps -eo pid,ppid,command | awk -v ppid="$current_shell_pid" -v proc="$1" '$2 == ppid && $3 ~ proc {print $1}')
|
||||||
echo "No GPU processes found."
|
if [ -n "$processes" ]; then
|
||||||
|
echo "Killing the following processes matching '$1':"
|
||||||
|
echo "$processes"
|
||||||
|
echo "$processes" | xargs kill -9
|
||||||
else
|
else
|
||||||
for pid in $pids; do
|
echo "No processes found matching '$1'."
|
||||||
kill -9 "$pid"
|
|
||||||
echo "Killed process with PID: $pid"
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "All GPU processes have been killed."
|
|
||||||
fi
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
# waiting for GPU processes to be fully killed
|
kill_gpu_processes() {
|
||||||
# loop while nvidia-smi returns any processes
|
|
||||||
while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
|
ps -aux
|
||||||
|
lsof -t -i:8000 | xargs -r kill -9
|
||||||
|
pkill -f pt_main_thread
|
||||||
|
# this line doesn't work now
|
||||||
|
# ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9
|
||||||
|
pkill -f python3
|
||||||
|
pkill -f /usr/bin/python3
|
||||||
|
|
||||||
|
|
||||||
|
# wait until GPU memory usage smaller than 1GB
|
||||||
|
while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do
|
||||||
sleep 1
|
sleep 1
|
||||||
echo "Waiting for GPU processes to be killed"
|
|
||||||
done
|
done
|
||||||
|
|
||||||
# remove vllm config file
|
# remove vllm config file
|
||||||
rm -rf ~/.config/vllm
|
rm -rf ~/.config/vllm
|
||||||
|
|
||||||
# Print the GPU memory usage
|
|
||||||
# so that we know if all GPU processes are killed.
|
|
||||||
gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
|
|
||||||
# The memory usage should be 0 MB.
|
|
||||||
echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
upload_to_buildkite() {
|
upload_to_buildkite() {
|
||||||
@@ -114,7 +117,7 @@ upload_to_buildkite() {
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Use the determined command to annotate and upload artifacts
|
# Use the determined command to annotate and upload artifacts
|
||||||
$BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < $RESULTS_FOLDER/benchmark_results.md
|
$BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" <$RESULTS_FOLDER/benchmark_results.md
|
||||||
$BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*"
|
$BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +169,7 @@ run_latency_tests() {
|
|||||||
latency_command: $latency,
|
latency_command: $latency,
|
||||||
gpu_type: $gpu
|
gpu_type: $gpu
|
||||||
}')
|
}')
|
||||||
echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands"
|
echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands"
|
||||||
|
|
||||||
# run the benchmark
|
# run the benchmark
|
||||||
eval "$latency_command"
|
eval "$latency_command"
|
||||||
@@ -176,7 +179,6 @@ run_latency_tests() {
|
|||||||
done
|
done
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
run_throughput_tests() {
|
run_throughput_tests() {
|
||||||
# run throughput tests using `benchmark_throughput.py`
|
# run throughput tests using `benchmark_throughput.py`
|
||||||
# $1: a json file specifying throughput test cases
|
# $1: a json file specifying throughput test cases
|
||||||
@@ -224,7 +226,7 @@ run_throughput_tests() {
|
|||||||
throughput_command: $command,
|
throughput_command: $command,
|
||||||
gpu_type: $gpu
|
gpu_type: $gpu
|
||||||
}')
|
}')
|
||||||
echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands"
|
echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands"
|
||||||
|
|
||||||
# run the benchmark
|
# run the benchmark
|
||||||
eval "$throughput_command"
|
eval "$throughput_command"
|
||||||
@@ -256,7 +258,6 @@ run_serving_tests() {
|
|||||||
continue
|
continue
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
# get client and server arguments
|
# get client and server arguments
|
||||||
server_params=$(echo "$params" | jq -r '.server_parameters')
|
server_params=$(echo "$params" | jq -r '.server_parameters')
|
||||||
client_params=$(echo "$params" | jq -r '.client_parameters')
|
client_params=$(echo "$params" | jq -r '.client_parameters')
|
||||||
@@ -334,7 +335,7 @@ run_serving_tests() {
|
|||||||
client_command: $client,
|
client_command: $client,
|
||||||
gpu_type: $gpu
|
gpu_type: $gpu
|
||||||
}')
|
}')
|
||||||
echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands"
|
echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands"
|
||||||
|
|
||||||
done
|
done
|
||||||
|
|
||||||
@@ -351,6 +352,7 @@ main() {
|
|||||||
# dependencies
|
# dependencies
|
||||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||||
(which jq) || (apt-get update && apt-get -y install jq)
|
(which jq) || (apt-get update && apt-get -y install jq)
|
||||||
|
(which lsof) || (apt-get update && apt-get install -y lsof)
|
||||||
|
|
||||||
# get the current IP address, required by benchmark_serving.py
|
# get the current IP address, required by benchmark_serving.py
|
||||||
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
|
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
|
||||||
@@ -369,7 +371,6 @@ main() {
|
|||||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json
|
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json
|
||||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json
|
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json
|
||||||
|
|
||||||
|
|
||||||
# postprocess benchmarking results
|
# postprocess benchmarking results
|
||||||
pip install tabulate pandas
|
pip install tabulate pandas
|
||||||
python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py
|
python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
{
|
{
|
||||||
"test_name": "latency_llama8B_tp1",
|
"test_name": "latency_llama8B_tp1",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-8B",
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
"tensor_parallel_size": 1,
|
"tensor_parallel_size": 1,
|
||||||
"load_format": "dummy",
|
"load_format": "dummy",
|
||||||
"num_iters_warmup": 5,
|
"num_iters_warmup": 5,
|
||||||
@@ -12,7 +12,7 @@
|
|||||||
{
|
{
|
||||||
"test_name": "latency_llama70B_tp4",
|
"test_name": "latency_llama70B_tp4",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
"tensor_parallel_size": 4,
|
"tensor_parallel_size": 4,
|
||||||
"load_format": "dummy",
|
"load_format": "dummy",
|
||||||
"num-iters-warmup": 5,
|
"num-iters-warmup": 5,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||||
"qps_list": [1, 4, 16, "inf"],
|
"qps_list": [1, 4, 16, "inf"],
|
||||||
"server_parameters": {
|
"server_parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-8B",
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
"tensor_parallel_size": 1,
|
"tensor_parallel_size": 1,
|
||||||
"swap_space": 16,
|
"swap_space": 16,
|
||||||
"disable_log_stats": "",
|
"disable_log_stats": "",
|
||||||
@@ -11,7 +11,7 @@
|
|||||||
"load_format": "dummy"
|
"load_format": "dummy"
|
||||||
},
|
},
|
||||||
"client_parameters": {
|
"client_parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-8B",
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
@@ -22,7 +22,7 @@
|
|||||||
"test_name": "serving_llama70B_tp4_sharegpt",
|
"test_name": "serving_llama70B_tp4_sharegpt",
|
||||||
"qps_list": [1, 4, 16, "inf"],
|
"qps_list": [1, 4, 16, "inf"],
|
||||||
"server_parameters": {
|
"server_parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
"tensor_parallel_size": 4,
|
"tensor_parallel_size": 4,
|
||||||
"swap_space": 16,
|
"swap_space": 16,
|
||||||
"disable_log_stats": "",
|
"disable_log_stats": "",
|
||||||
@@ -30,7 +30,7 @@
|
|||||||
"load_format": "dummy"
|
"load_format": "dummy"
|
||||||
},
|
},
|
||||||
"client_parameters": {
|
"client_parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
@@ -60,7 +60,7 @@
|
|||||||
"test_name": "serving_llama70B_tp4_sharegpt_specdecode",
|
"test_name": "serving_llama70B_tp4_sharegpt_specdecode",
|
||||||
"qps_list": [2],
|
"qps_list": [2],
|
||||||
"server_parameters": {
|
"server_parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
"disable_log_requests": "",
|
"disable_log_requests": "",
|
||||||
"tensor_parallel_size": 4,
|
"tensor_parallel_size": 4,
|
||||||
"swap_space": 16,
|
"swap_space": 16,
|
||||||
@@ -70,7 +70,7 @@
|
|||||||
"use_v2_block_manager": ""
|
"use_v2_block_manager": ""
|
||||||
},
|
},
|
||||||
"client_parameters": {
|
"client_parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
{
|
{
|
||||||
"test_name": "throughput_llama8B_tp1",
|
"test_name": "throughput_llama8B_tp1",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-8B",
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
"tensor_parallel_size": 1,
|
"tensor_parallel_size": 1,
|
||||||
"load_format": "dummy",
|
"load_format": "dummy",
|
||||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
{
|
{
|
||||||
"test_name": "throughput_llama70B_tp4",
|
"test_name": "throughput_llama70B_tp4",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
"tensor_parallel_size": 4,
|
"tensor_parallel_size": 4,
|
||||||
"load_format": "dummy",
|
"load_format": "dummy",
|
||||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
|||||||
@@ -1,9 +1,27 @@
|
|||||||
steps:
|
steps:
|
||||||
- label: "Build wheel - CUDA {{matrix.cuda_version}}"
|
- label: "Build wheel - CUDA 12.1"
|
||||||
agents:
|
agents:
|
||||||
queue: cpu_queue
|
queue: cpu_queue
|
||||||
commands:
|
commands:
|
||||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ."
|
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ."
|
||||||
|
- "mkdir artifacts"
|
||||||
|
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||||
|
# rename the files to change linux -> manylinux1
|
||||||
|
- "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done"
|
||||||
|
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/"
|
||||||
|
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
|
||||||
|
env:
|
||||||
|
DOCKER_BUILDKIT: "1"
|
||||||
|
|
||||||
|
- block: "Build CUDA 11.8 wheel"
|
||||||
|
key: block-build-cu118-wheel
|
||||||
|
|
||||||
|
- label: "Build wheel - CUDA 11.8"
|
||||||
|
depends_on: block-build-cu118-wheel
|
||||||
|
agents:
|
||||||
|
queue: cpu_queue
|
||||||
|
commands:
|
||||||
|
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ."
|
||||||
- "mkdir artifacts"
|
- "mkdir artifacts"
|
||||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||||
# rename the files to change linux -> manylinux1
|
# rename the files to change linux -> manylinux1
|
||||||
@@ -12,8 +30,3 @@ steps:
|
|||||||
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
|
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: "1"
|
DOCKER_BUILDKIT: "1"
|
||||||
matrix:
|
|
||||||
setup:
|
|
||||||
cuda_version:
|
|
||||||
- "11.8.0"
|
|
||||||
- "12.1.0"
|
|
||||||
|
|||||||
69
.buildkite/run-amd-test.sh
Normal file → Executable file
69
.buildkite/run-amd-test.sh
Normal file → Executable file
@@ -1,5 +1,5 @@
|
|||||||
# This script runs test inside the corresponding ROCm docker container.
|
# This script runs test inside the corresponding ROCm docker container.
|
||||||
set -ex
|
set -o pipefail
|
||||||
|
|
||||||
# Print ROCm version
|
# Print ROCm version
|
||||||
echo "--- Confirming Clean Initial State"
|
echo "--- Confirming Clean Initial State"
|
||||||
@@ -70,15 +70,74 @@ HF_CACHE="$(realpath ~)/huggingface"
|
|||||||
mkdir -p ${HF_CACHE}
|
mkdir -p ${HF_CACHE}
|
||||||
HF_MOUNT="/root/.cache/huggingface"
|
HF_MOUNT="/root/.cache/huggingface"
|
||||||
|
|
||||||
docker run \
|
commands=$@
|
||||||
|
echo "Commands:$commands"
|
||||||
|
#ignore certain kernels tests
|
||||||
|
if [[ $commands == *" kernels "* ]]; then
|
||||||
|
commands="${commands} \
|
||||||
|
--ignore=kernels/test_attention.py \
|
||||||
|
--ignore=kernels/test_attention_selector.py \
|
||||||
|
--ignore=kernels/test_blocksparse_attention.py \
|
||||||
|
--ignore=kernels/test_causal_conv1d.py \
|
||||||
|
--ignore=kernels/test_cutlass.py \
|
||||||
|
--ignore=kernels/test_encoder_decoder_attn.py \
|
||||||
|
--ignore=kernels/test_flash_attn.py \
|
||||||
|
--ignore=kernels/test_flashinfer.py \
|
||||||
|
--ignore=kernels/test_int8_quant.py \
|
||||||
|
--ignore=kernels/test_machete_gemm.py \
|
||||||
|
--ignore=kernels/test_mamba_ssm.py \
|
||||||
|
--ignore=kernels/test_marlin_gemm.py \
|
||||||
|
--ignore=kernels/test_moe.py \
|
||||||
|
--ignore=kernels/test_prefix_prefill.py \
|
||||||
|
--ignore=kernels/test_rand.py \
|
||||||
|
--ignore=kernels/test_sampler.py"
|
||||||
|
fi
|
||||||
|
|
||||||
|
PARALLEL_JOB_COUNT=8
|
||||||
|
# Check if the command contains the shard flag; if so, run all shards in parallel because the host has 8 GPUs.
|
||||||
|
if [[ $commands == *"--shard-id="* ]]; then
|
||||||
|
for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
|
||||||
|
#replace shard arguments
|
||||||
|
commands=${commands//"--shard-id= "/"--shard-id=${GPU} "}
|
||||||
|
commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
|
||||||
|
echo "Shard ${GPU} commands:$commands"
|
||||||
|
docker run \
|
||||||
--device /dev/kfd --device /dev/dri \
|
--device /dev/kfd --device /dev/dri \
|
||||||
--network host \
|
--network host \
|
||||||
--shm-size=16gb \
|
--shm-size=16gb \
|
||||||
--rm \
|
--rm \
|
||||||
|
-e HIP_VISIBLE_DEVICES=${GPU} \
|
||||||
-e HF_TOKEN \
|
-e HF_TOKEN \
|
||||||
-v ${HF_CACHE}:${HF_MOUNT} \
|
-v ${HF_CACHE}:${HF_MOUNT} \
|
||||||
-e HF_HOME=${HF_MOUNT} \
|
-e HF_HOME=${HF_MOUNT} \
|
||||||
--name ${container_name} \
|
--name ${container_name}_${GPU} \
|
||||||
${image_name} \
|
${image_name} \
|
||||||
/bin/bash -c "${@}"
|
/bin/bash -c "${commands}" \
|
||||||
|
|& while read -r line; do echo ">>Shard $GPU: $line"; done &
|
||||||
|
PIDS+=($!)
|
||||||
|
done
|
||||||
|
#wait for all processes to finish and collect exit codes
|
||||||
|
for pid in ${PIDS[@]}; do
|
||||||
|
wait ${pid}
|
||||||
|
STATUS+=($?)
|
||||||
|
done
|
||||||
|
for st in ${STATUS[@]}; do
|
||||||
|
if [[ ${st} -ne 0 ]]; then
|
||||||
|
echo "One of the processes failed with $st"
|
||||||
|
exit ${st}
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
else
|
||||||
|
docker run \
|
||||||
|
--device /dev/kfd --device /dev/dri \
|
||||||
|
--network host \
|
||||||
|
--shm-size=16gb \
|
||||||
|
--rm \
|
||||||
|
-e HIP_VISIBLE_DEVICES=0 \
|
||||||
|
-e HF_TOKEN \
|
||||||
|
-v ${HF_CACHE}:${HF_MOUNT} \
|
||||||
|
-e HF_HOME=${HF_MOUNT} \
|
||||||
|
--name ${container_name} \
|
||||||
|
${image_name} \
|
||||||
|
/bin/bash -c "${commands}"
|
||||||
|
fi
|
||||||
|
|||||||
33
.buildkite/run-cpu-test-ppc64le.sh
Executable file
33
.buildkite/run-cpu-test-ppc64le.sh
Executable file
@@ -0,0 +1,33 @@
|
|||||||
|
# This script builds the CPU docker image and runs the offline inference inside the container.
|
||||||
|
# It serves as a sanity check for compilation and basic model usage.
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# Try building the docker image
|
||||||
|
docker build -t cpu-test -f Dockerfile.ppc64le .
|
||||||
|
|
||||||
|
# Setup cleanup
|
||||||
|
remove_docker_container() { docker rm -f cpu-test || true; }
|
||||||
|
trap remove_docker_container EXIT
|
||||||
|
remove_docker_container
|
||||||
|
|
||||||
|
# Run the image, setting --shm-size=4g for tensor parallel.
|
||||||
|
source /etc/environment
|
||||||
|
#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
|
||||||
|
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN=$HF_TOKEN --name cpu-test cpu-test
|
||||||
|
|
||||||
|
# Run basic model test
|
||||||
|
docker exec cpu-test bash -c "
|
||||||
|
pip install pytest matplotlib einops transformers_stream_generator
|
||||||
|
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
||||||
|
|
||||||
|
# online inference
|
||||||
|
docker exec cpu-test bash -c "
|
||||||
|
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
|
||||||
|
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||||
|
python3 benchmarks/benchmark_serving.py \
|
||||||
|
--backend vllm \
|
||||||
|
--dataset-name random \
|
||||||
|
--model facebook/opt-125m \
|
||||||
|
--num-prompts 20 \
|
||||||
|
--endpoint /v1/completions \
|
||||||
|
--tokenizer facebook/opt-125m"
|
||||||
@@ -22,8 +22,19 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
|
|||||||
|
|
||||||
# Run basic model test
|
# Run basic model test
|
||||||
docker exec cpu-test bash -c "
|
docker exec cpu-test bash -c "
|
||||||
pip install pytest Pillow protobuf
|
pip install pytest matplotlib einops transformers_stream_generator
|
||||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \
|
||||||
|
--ignore=tests/models/test_oot_registration.py \
|
||||||
|
--ignore=tests/models/test_registry.py \
|
||||||
|
--ignore=tests/models/test_fp8.py \
|
||||||
|
--ignore=tests/models/test_jamba.py \
|
||||||
|
--ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
||||||
|
|
||||||
|
# Run compressed-tensor test
|
||||||
|
docker exec cpu-test bash -c "
|
||||||
|
pytest -s -v \
|
||||||
|
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
|
||||||
|
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token"
|
||||||
|
|
||||||
# online inference
|
# online inference
|
||||||
docker exec cpu-test bash -c "
|
docker exec cpu-test bash -c "
|
||||||
|
|||||||
@@ -12,5 +12,4 @@ remove_docker_container
|
|||||||
# For HF_TOKEN.
|
# For HF_TOKEN.
|
||||||
source /etc/environment
|
source /etc/environment
|
||||||
# Run a simple end-to-end example.
|
# Run a simple end-to-end example.
|
||||||
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
|
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
|
||||||
python3 /workspace/vllm/examples/offline_inference_tpu.py
|
|
||||||
|
|||||||
@@ -5,264 +5,428 @@
|
|||||||
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
|
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
|
||||||
# to generate the final pipeline yaml file.
|
# to generate the final pipeline yaml file.
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
# label(str): the name of the test. emoji allowed.
|
||||||
|
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
|
||||||
|
# fast_check_only(bool): run this test on fastcheck pipeline only
|
||||||
|
# command(str): the single command to run for tests. incompatible with commands.
|
||||||
|
# commands(list): the list of commands to run for test. incompatible with command.
|
||||||
|
# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
|
||||||
|
# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100
|
||||||
|
# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4.
|
||||||
|
# num_nodes(int): whether to simulate multi-node setup by launching multiple containers on one host,
|
||||||
|
# in this case, commands must be specified. the first command runs on first host, the second
|
||||||
|
# command runs on the second host.
|
||||||
|
# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests
|
||||||
|
# source_file_dependencies(list): the list of prefixes that opt the test in; if empty, the test will always run.
|
||||||
|
|
||||||
|
# When adding a test
|
||||||
|
# - If the test belongs to an existing group, add it there
|
||||||
|
# - If the test is short, add to any existing step
|
||||||
|
# - If the test takes more than 10min, then it is okay to create a new step.
|
||||||
|
# Note that all steps execute in parallel.
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- label: Async Engine, Inputs, Utils, Worker Test
|
##### fast check tests #####
|
||||||
fast_check: true
|
|
||||||
fast_check_only: true
|
|
||||||
commands:
|
|
||||||
- pytest -v -s async_engine # Async Engine
|
|
||||||
- pytest -v -s test_inputs.py
|
|
||||||
- pytest -v -s multimodal
|
|
||||||
- pytest -v -s test_utils.py # Utils
|
|
||||||
- pytest -v -s worker # Worker
|
|
||||||
|
|
||||||
- label: Metrics, Tracing Test
|
- label: Documentation Build # 2min
|
||||||
fast_check: true
|
|
||||||
fast_check_only: true
|
|
||||||
commands:
|
|
||||||
- pytest -v -s metrics # Metrics
|
|
||||||
- "pip install \
|
|
||||||
opentelemetry-sdk \
|
|
||||||
opentelemetry-api \
|
|
||||||
opentelemetry-exporter-otlp \
|
|
||||||
opentelemetry-semantic-conventions-ai" # Tracing
|
|
||||||
- pytest -v -s tracing
|
|
||||||
|
|
||||||
- label: Regression Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
fast_check: true
|
|
||||||
command: pytest -v -s test_regression.py
|
|
||||||
working_dir: "/vllm-workspace/tests" # optional
|
|
||||||
|
|
||||||
- label: AsyncEngine Test
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
command: pytest -v -s async_engine
|
|
||||||
|
|
||||||
- label: Basic Correctness Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
fast_check: true
|
|
||||||
commands:
|
|
||||||
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
|
|
||||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true
|
|
||||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
|
||||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
|
||||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
|
||||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
|
||||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
|
||||||
|
|
||||||
- label: Core Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
fast_check: true
|
|
||||||
commands:
|
|
||||||
- pytest -v -s core
|
|
||||||
|
|
||||||
- label: Distributed Comm Ops Test
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
|
||||||
num_gpus: 2
|
|
||||||
commands:
|
|
||||||
- pytest -v -s distributed/test_comm_ops.py
|
|
||||||
- pytest -v -s distributed/test_shm_broadcast.py
|
|
||||||
|
|
||||||
- label: 2 Node Tests (4 GPUs in total)
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
|
||||||
num_gpus: 2
|
|
||||||
num_nodes: 2
|
|
||||||
commands:
|
|
||||||
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
|
||||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
|
|
||||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
|
||||||
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
|
||||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
|
|
||||||
|
|
||||||
- label: Distributed Tests (2 GPUs)
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
|
||||||
num_gpus: 2
|
|
||||||
commands:
|
|
||||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
|
|
||||||
- TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
|
|
||||||
- pytest -v -s distributed/test_chunked_prefill_distributed.py
|
|
||||||
- pytest -v -s distributed/test_multimodal_broadcast.py
|
|
||||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
|
||||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
|
||||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
|
|
||||||
|
|
||||||
- label: Distributed Tests (4 GPUs)
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
|
||||||
num_gpus: 4
|
|
||||||
fast_check: true
|
|
||||||
commands:
|
|
||||||
- pytest -v -s distributed/test_pynccl.py
|
|
||||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
|
||||||
|
|
||||||
- label: Pipeline Parallelism Test
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
|
||||||
num_gpus: 4
|
|
||||||
commands:
|
|
||||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
|
||||||
|
|
||||||
- label: Engine Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
commands:
|
|
||||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py
|
|
||||||
# OOM in the CI unless we run this separately
|
|
||||||
- pytest -v -s tokenization
|
|
||||||
|
|
||||||
- label: Entrypoints Test
|
|
||||||
fast_check: true
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
|
|
||||||
commands:
|
|
||||||
- pytest -v -s entrypoints/llm
|
|
||||||
- pytest -v -s entrypoints/openai
|
|
||||||
|
|
||||||
- label: Examples Test
|
|
||||||
working_dir: "/vllm-workspace/examples"
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
commands:
|
|
||||||
# install tensorizer for tensorize_vllm_model.py
|
|
||||||
- pip install awscli tensorizer
|
|
||||||
- python3 offline_inference.py
|
|
||||||
- python3 cpu_offload.py
|
|
||||||
- python3 offline_inference_with_prefix.py
|
|
||||||
- python3 llm_engine_example.py
|
|
||||||
- python3 offline_inference_vision_language.py
|
|
||||||
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
|
||||||
|
|
||||||
- label: Inputs Test
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
commands:
|
|
||||||
- pytest -v -s test_inputs.py
|
|
||||||
- pytest -v -s multimodal
|
|
||||||
|
|
||||||
# - label: Kernels Test %N
|
|
||||||
# #mirror_hardwares: [amd]
|
|
||||||
# commands:
|
|
||||||
# - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
|
||||||
# - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
|
||||||
# parallelism: 4
|
|
||||||
|
|
||||||
- label: Models Test
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
commands:
|
|
||||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
|
||||||
- pytest -v -s models -m \"not vlm\"
|
|
||||||
|
|
||||||
- label: Vision Language Models Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
commands:
|
|
||||||
- pytest -v -s models -m vlm
|
|
||||||
|
|
||||||
- label: Prefix Caching Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
commands:
|
|
||||||
- pytest -v -s prefix_caching
|
|
||||||
|
|
||||||
- label: Samplers Test
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
command: pytest -v -s samplers
|
|
||||||
|
|
||||||
- label: LogitsProcessor Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
command: pytest -v -s test_logits_processor.py
|
|
||||||
|
|
||||||
- label: Utils Test
|
|
||||||
commands:
|
|
||||||
- pytest -v -s test_utils.py
|
|
||||||
- pytest -v -s test_embedded_commit.py
|
|
||||||
|
|
||||||
- label: Worker Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
command: pytest -v -s worker
|
|
||||||
|
|
||||||
- label: Speculative decoding tests
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
commands:
|
|
||||||
# See https://github.com/vllm-project/vllm/issues/5152
|
|
||||||
- export VLLM_ATTENTION_BACKEND=XFORMERS
|
|
||||||
- pytest -v -s spec_decode
|
|
||||||
|
|
||||||
# - label: LoRA Test %N
|
|
||||||
# #mirror_hardwares: [amd]
|
|
||||||
# command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
|
|
||||||
# parallelism: 4
|
|
||||||
|
|
||||||
# - label: LoRA Long Context (Distributed)
|
|
||||||
# #mirror_hardwares: [amd]
|
|
||||||
# num_gpus: 4
|
|
||||||
# # This test runs llama 13B, so it is required to run on 4 GPUs.
|
|
||||||
# commands:
|
|
||||||
# # FIXIT: find out which code initialize cuda before running the test
|
|
||||||
# # before the fix, we need to use spawn to test it
|
|
||||||
# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
|
||||||
# - pytest -v -s -x lora/test_long_context.py
|
|
||||||
|
|
||||||
- label: Tensorizer Test
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
fast_check: true
|
|
||||||
commands:
|
|
||||||
- apt-get install -y curl libsodium23
|
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
|
||||||
- pytest -v -s tensorizer_loader
|
|
||||||
|
|
||||||
- label: Metrics Test
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
command: pytest -v -s metrics
|
|
||||||
|
|
||||||
- label: Quantization Test
|
|
||||||
#mirror_hardwares: [amd]
|
|
||||||
command: pytest -v -s quantization
|
|
||||||
|
|
||||||
- label: Tracing Test
|
|
||||||
commands:
|
|
||||||
- "pip install \
|
|
||||||
opentelemetry-sdk \
|
|
||||||
opentelemetry-api \
|
|
||||||
opentelemetry-exporter-otlp \
|
|
||||||
opentelemetry-semantic-conventions-ai"
|
|
||||||
- pytest -v -s tracing
|
|
||||||
|
|
||||||
- label: Benchmarks
|
|
||||||
working_dir: "/vllm-workspace/.buildkite"
|
|
||||||
mirror_hardwares: [amd]
|
|
||||||
commands:
|
|
||||||
- pip install aiohttp
|
|
||||||
- bash run-benchmarks.sh
|
|
||||||
|
|
||||||
- label: LM Eval Small Models
|
|
||||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
|
||||||
commands:
|
|
||||||
- pip install lm-eval
|
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
|
||||||
- bash ./run-tests.sh -c configs/models-small.txt -t 1
|
|
||||||
|
|
||||||
- label: LM Eval Large Models
|
|
||||||
gpu: a100
|
|
||||||
num_gpus: 4
|
|
||||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
|
||||||
commands:
|
|
||||||
- pip install lm-eval
|
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
|
||||||
- bash ./run-tests.sh -c configs/models-large.txt -t 4
|
|
||||||
|
|
||||||
- label: Documentation Build
|
|
||||||
working_dir: "/vllm-workspace/test_docs/docs"
|
working_dir: "/vllm-workspace/test_docs/docs"
|
||||||
fast_check: true
|
fast_check: true
|
||||||
no_gpu: True
|
no_gpu: True
|
||||||
commands:
|
commands:
|
||||||
- pip install -r requirements-docs.txt
|
- pip install -r requirements-docs.txt
|
||||||
- SPHINXOPTS=\"-W\" make html
|
- SPHINXOPTS=\"-W\" make html
|
||||||
|
# Check API reference (if it fails, you may have missing mock imports)
|
||||||
|
- grep \"sig sig-object py\" build/html/dev/sampling_params.html
|
||||||
|
|
||||||
- label: Distributed Tests (A100)
|
- label: Async Engine, Inputs, Utils, Worker Test # 15min
|
||||||
|
fast_check: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/async_engine
|
||||||
|
- tests/test_inputs
|
||||||
|
- tests/multimodal
|
||||||
|
- tests/test_utils
|
||||||
|
- tests/worker
|
||||||
|
commands:
|
||||||
|
- pytest -v -s async_engine # Async Engine
|
||||||
|
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
|
||||||
|
- pytest -v -s test_inputs.py
|
||||||
|
- pytest -v -s multimodal
|
||||||
|
- pytest -v -s test_utils.py # Utils
|
||||||
|
- pytest -v -s worker # Worker
|
||||||
|
|
||||||
|
- label: Basic Correctness Test # 30min
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
|
fast_check: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/basic_correctness
|
||||||
|
commands:
|
||||||
|
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||||
|
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||||
|
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||||
|
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||||
|
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||||
|
|
||||||
|
- label: Core Test # 10min
|
||||||
|
mirror_hardwares: [amd]
|
||||||
|
fast_check: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/core
|
||||||
|
- vllm/distributed
|
||||||
|
- tests/core
|
||||||
|
commands:
|
||||||
|
- pytest -v -s core
|
||||||
|
|
||||||
|
- label: Entrypoints Test # 20min
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
fast_check: true
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
commands:
|
||||||
|
- pip install -e ./plugins/vllm_add_dummy_model
|
||||||
|
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
|
||||||
|
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
|
||||||
|
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||||
|
- pytest -v -s entrypoints/openai
|
||||||
|
- pytest -v -s entrypoints/test_chat_utils.py
|
||||||
|
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||||
|
|
||||||
|
|
||||||
|
- label: Distributed Tests (4 GPUs) # 10min
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_gpus: 4
|
||||||
|
fast_check: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/distributed/
|
||||||
|
- vllm/core/
|
||||||
|
- tests/distributed
|
||||||
|
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||||
|
commands:
|
||||||
|
- pytest -v -s distributed/test_pynccl.py
|
||||||
|
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||||
|
|
||||||
|
- label: Metrics, Tracing Test # 10min
|
||||||
|
num_gpus: 2
|
||||||
|
fast_check: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/metrics
|
||||||
|
- tests/tracing
|
||||||
|
commands:
|
||||||
|
- pytest -v -s metrics
|
||||||
|
- "pip install \
|
||||||
|
'opentelemetry-sdk>=1.26.0,<1.27.0' \
|
||||||
|
'opentelemetry-api>=1.26.0,<1.27.0' \
|
||||||
|
'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \
|
||||||
|
'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'"
|
||||||
|
- pytest -v -s tracing
|
||||||
|
|
||||||
|
##### fast check tests #####
|
||||||
|
##### 1 GPU test #####
|
||||||
|
|
||||||
|
- label: Regression Test # 5min
|
||||||
|
mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/test_regression
|
||||||
|
command: pytest -v -s test_regression.py
|
||||||
|
working_dir: "/vllm-workspace/tests" # optional
|
||||||
|
|
||||||
|
- label: Engine Test # 10min
|
||||||
|
mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/engine
|
||||||
|
- tests/tokenization
|
||||||
|
commands:
|
||||||
|
- pytest -v -s engine test_sequence.py test_config.py test_logger.py
|
||||||
|
# OOM in the CI unless we run this separately
|
||||||
|
- pytest -v -s tokenization
|
||||||
|
|
||||||
|
- label: Examples Test # 12min
|
||||||
|
working_dir: "/vllm-workspace/examples"
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/entrypoints
|
||||||
|
- examples/
|
||||||
|
commands:
|
||||||
|
- pip install awscli tensorizer # for llava example and tensorizer test
|
||||||
|
- python3 offline_inference.py
|
||||||
|
- python3 cpu_offload.py
|
||||||
|
- python3 offline_inference_chat.py
|
||||||
|
- python3 offline_inference_with_prefix.py
|
||||||
|
- python3 llm_engine_example.py
|
||||||
|
- python3 offline_inference_vision_language.py
|
||||||
|
- python3 offline_inference_vision_language_multi_image.py
|
||||||
|
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||||
|
- python3 offline_inference_encoder_decoder.py
|
||||||
|
|
||||||
|
- label: Models Test # 1hr10min
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/models
|
||||||
|
commands:
|
||||||
|
- pip install -e ./plugins/vllm_add_dummy_model
|
||||||
|
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||||
|
- pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py
|
||||||
|
|
||||||
|
- label: torch compile integration test
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
commands:
|
||||||
|
- pytest -v -s ./compile/test_full_graph.py
|
||||||
|
- pytest -v -s ./compile/test_wrapper.py
|
||||||
|
|
||||||
|
|
||||||
|
- label: Vision Language Models Test # 42min
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
commands:
|
||||||
|
- pytest -v -s models -m vlm
|
||||||
|
|
||||||
|
- label: Prefix Caching Test # 7min
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/prefix_caching
|
||||||
|
commands:
|
||||||
|
- pytest -v -s prefix_caching
|
||||||
|
|
||||||
|
- label: Samplers Test # 18min
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/model_executor/layers
|
||||||
|
- vllm/sampling_metadata.py
|
||||||
|
- tests/samplers
|
||||||
|
commands:
|
||||||
|
- pytest -v -s samplers
|
||||||
|
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||||
|
|
||||||
|
- label: LogitsProcessor Test # 5min
|
||||||
|
mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/model_executor/layers
|
||||||
|
- tests/test_logits_processor
|
||||||
|
command: pytest -v -s test_logits_processor.py
|
||||||
|
|
||||||
|
- label: Speculative decoding tests # 22min
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/spec_decode
|
||||||
|
- tests/spec_decode
|
||||||
|
commands:
|
||||||
|
# See https://github.com/vllm-project/vllm/issues/5152
|
||||||
|
- export VLLM_ATTENTION_BACKEND=XFORMERS
|
||||||
|
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
||||||
|
- pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
|
||||||
|
|
||||||
|
- label: LoRA Test %N # 30min each
|
||||||
|
mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/lora
|
||||||
|
- tests/lora
|
||||||
|
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
|
||||||
|
parallelism: 4
|
||||||
|
|
||||||
|
- label: Kernels Test %N # 30min each
|
||||||
|
mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- csrc/
|
||||||
|
- vllm/attention
|
||||||
|
- tests/kernels
|
||||||
|
commands:
|
||||||
|
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||||
|
parallelism: 4
|
||||||
|
|
||||||
|
- label: Tensorizer Test # 11min
|
||||||
|
mirror_hardwares: [amd]
|
||||||
|
soft_fail: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/model_executor/model_loader
|
||||||
|
- tests/tensorizer_loader
|
||||||
|
commands:
|
||||||
|
- apt-get update && apt-get install -y curl libsodium23
|
||||||
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
|
- pytest -v -s tensorizer_loader
|
||||||
|
|
||||||
|
- label: Benchmarks # 9min
|
||||||
|
working_dir: "/vllm-workspace/.buildkite"
|
||||||
|
mirror_hardwares: [amd]
|
||||||
|
source_file_dependencies:
|
||||||
|
- benchmarks/
|
||||||
|
commands:
|
||||||
|
- pip install aiohttp
|
||||||
|
- bash run-benchmarks.sh
|
||||||
|
|
||||||
|
- label: Quantization Test # 15min
|
||||||
|
source_file_dependencies:
|
||||||
|
- csrc/
|
||||||
|
- vllm/model_executor/layers/quantization
|
||||||
|
- tests/quantization
|
||||||
|
command: pytest -v -s quantization
|
||||||
|
|
||||||
|
- label: LM Eval Small Models # 53min
|
||||||
|
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||||
|
source_file_dependencies:
|
||||||
|
- csrc/
|
||||||
|
- vllm/model_executor/layers/quantization
|
||||||
|
commands:
|
||||||
|
- pip install lm-eval
|
||||||
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
|
- bash ./run-tests.sh -c configs/models-small.txt -t 1
|
||||||
|
|
||||||
|
- label: OpenAI-Compatible Tool Use # 20 min
|
||||||
|
fast_check: false
|
||||||
|
mirror_hardwares: [ amd ]
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/tool_use
|
||||||
|
commands:
|
||||||
|
- pytest -v -s tool_use
|
||||||
|
|
||||||
|
##### 1 GPU test #####
|
||||||
|
##### multi gpus test #####
|
||||||
|
|
||||||
|
- label: Distributed Comm Ops Test # 7min
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_gpus: 2
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/distributed
|
||||||
|
- tests/distributed
|
||||||
|
commands:
|
||||||
|
- pytest -v -s distributed/test_comm_ops.py
|
||||||
|
- pytest -v -s distributed/test_shm_broadcast.py
|
||||||
|
|
||||||
|
- label: 2 Node Tests (4 GPUs in total) # 16min
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_gpus: 2
|
||||||
|
num_nodes: 2
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/distributed/
|
||||||
|
- vllm/engine/
|
||||||
|
- vllm/executor/
|
||||||
|
- vllm/model_executor/models/
|
||||||
|
- tests/distributed/
|
||||||
|
commands:
|
||||||
|
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
||||||
|
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
|
||||||
|
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
|
||||||
|
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||||
|
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
||||||
|
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
|
||||||
|
|
||||||
|
- label: Distributed Tests (2 GPUs) # 28min
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_gpus: 2
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/distributed/
|
||||||
|
- vllm/engine/
|
||||||
|
- vllm/executor/
|
||||||
|
- vllm/model_executor/models/
|
||||||
|
- tests/distributed/
|
||||||
|
commands:
|
||||||
|
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
|
||||||
|
- TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||||
|
- pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
|
||||||
|
- pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||||
|
- pytest -v -s distributed/test_multimodal_broadcast.py
|
||||||
|
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||||
|
- pip install -e ./plugins/vllm_add_dummy_model
|
||||||
|
- pytest -v -s distributed/test_distributed_oot.py
|
||||||
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||||
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
|
||||||
|
|
||||||
|
- label: Multi-step Tests (4 GPUs) # 21min
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_gpus: 4
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/model_executor/layers/sampler.py
|
||||||
|
- vllm/sequence.py
|
||||||
|
- vllm/worker/worker_base.py
|
||||||
|
- vllm/worker/worker.py
|
||||||
|
- vllm/worker/multi_step_worker.py
|
||||||
|
- vllm/worker/model_runner_base.py
|
||||||
|
- vllm/worker/model_runner.py
|
||||||
|
- vllm/worker/multi_step_model_runner.py
|
||||||
|
- vllm/engine
|
||||||
|
- tests/multi_step
|
||||||
|
commands:
|
||||||
|
- pytest -v -s multi_step/test_correctness_async_llm.py
|
||||||
|
- pytest -v -s multi_step/test_correctness_llm.py
|
||||||
|
|
||||||
|
- label: Pipeline Parallelism Test # 23min
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_gpus: 4
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/distributed/
|
||||||
|
- vllm/engine/
|
||||||
|
- vllm/executor/
|
||||||
|
- vllm/model_executor/models/
|
||||||
|
- tests/distributed/
|
||||||
|
commands:
|
||||||
|
- pytest -v -s distributed/test_pp_cudagraph.py
|
||||||
|
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||||
|
|
||||||
|
- label: LoRA Long Context (Distributed) # 11min
|
||||||
|
# This test runs llama 13B, so it is required to run on 4 GPUs.
|
||||||
|
num_gpus: 4
|
||||||
|
soft_fail: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/lora
|
||||||
|
- tests/lora/test_long_context
|
||||||
|
commands:
|
||||||
|
# FIXIT: find out which code initializes CUDA before running the test
|
||||||
|
# before the fix, we need to use spawn to test it
|
||||||
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
|
- pytest -v -s -x lora/test_long_context.py
|
||||||
|
|
||||||
|
- label: Weight Loading Multiple GPU Test
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_gpus: 2
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/weight_loading
|
||||||
|
commands:
|
||||||
|
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
|
||||||
|
|
||||||
|
- label: Weight Loading Multiple GPU Test - Large Models # optional
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_gpus: 2
|
||||||
|
gpu: a100
|
||||||
|
optional: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/weight_loading
|
||||||
|
commands:
|
||||||
|
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
|
||||||
|
|
||||||
|
|
||||||
|
##### multi gpus test #####
|
||||||
|
##### A100 test #####
|
||||||
|
|
||||||
|
- label: Distributed Tests (A100) # optional
|
||||||
gpu: a100
|
gpu: a100
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
commands:
|
commands:
|
||||||
# NOTE: don't test llama model here, it seems hf implementation is buggy
|
# NOTE: don't test llama model here, it seems hf implementation is buggy
|
||||||
# see https://github.com/vllm-project/vllm/pull/5689 for details
|
# see https://github.com/vllm-project/vllm/pull/5689 for details
|
||||||
- pytest -v -s distributed/test_custom_all_reduce.py
|
- pytest -v -s distributed/test_custom_all_reduce.py
|
||||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
|
||||||
- TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
|
- TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||||
- pytest -v -s -x lora/test_mixtral.py
|
- pytest -v -s -x lora/test_mixtral.py
|
||||||
|
|
||||||
|
- label: LM Eval Large Models # optional
|
||||||
|
gpu: a100
|
||||||
|
num_gpus: 4
|
||||||
|
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||||
|
source_file_dependencies:
|
||||||
|
- csrc/
|
||||||
|
- vllm/model_executor/layers/quantization
|
||||||
|
commands:
|
||||||
|
- pip install lm-eval
|
||||||
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
|
- bash ./run-tests.sh -c configs/models-large.txt -t 4
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
vllm/*.so
|
vllm/*.so
|
||||||
|
/.venv
|
||||||
|
/build
|
||||||
|
dist
|
||||||
|
|||||||
7
.github/ISSUE_TEMPLATE/100-documentation.yml
vendored
7
.github/ISSUE_TEMPLATE/100-documentation.yml
vendored
@@ -20,3 +20,10 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
value: >
|
value: >
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
7
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
7
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
@@ -38,3 +38,10 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
value: >
|
value: >
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
7
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
7
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
@@ -36,3 +36,10 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
value: >
|
value: >
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
23
.github/ISSUE_TEMPLATE/400-bug report.yml
vendored
23
.github/ISSUE_TEMPLATE/400-bug report.yml
vendored
@@ -20,11 +20,25 @@ body:
|
|||||||
```
|
```
|
||||||
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
|
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
|
||||||
value: |
|
value: |
|
||||||
|
<details>
|
||||||
|
<summary>The output of `python collect_env.py`</summary>
|
||||||
|
|
||||||
```text
|
```text
|
||||||
The output of `python collect_env.py`
|
Your output of `python collect_env.py` here
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Model Input Dumps
|
||||||
|
description: |
|
||||||
|
If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process.
|
||||||
|
placeholder: |
|
||||||
|
Upload the dumped input file.
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
label: 🐛 Describe the bug
|
label: 🐛 Describe the bug
|
||||||
@@ -84,3 +98,10 @@ body:
|
|||||||
- If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect.
|
- If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect.
|
||||||
|
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
@@ -29,3 +29,10 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
value: >
|
value: >
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
7
.github/ISSUE_TEMPLATE/600-new model.yml
vendored
7
.github/ISSUE_TEMPLATE/600-new model.yml
vendored
@@ -31,3 +31,10 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
value: >
|
value: >
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
@@ -50,3 +50,10 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
value: >
|
value: >
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
7
.github/ISSUE_TEMPLATE/750-RFC.yml
vendored
7
.github/ISSUE_TEMPLATE/750-RFC.yml
vendored
@@ -47,3 +47,10 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
value: >
|
value: >
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
@@ -19,3 +19,10 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
value: >
|
value: >
|
||||||
Thanks for contributing 🎉!
|
Thanks for contributing 🎉!
|
||||||
|
- type: checkboxes
|
||||||
|
id: askllm
|
||||||
|
attributes:
|
||||||
|
label: Before submitting a new issue...
|
||||||
|
options:
|
||||||
|
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||||
|
required: true
|
||||||
|
|||||||
10
.github/PULL_REQUEST_TEMPLATE.md
vendored
10
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
|
|||||||
<li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.</li>
|
<li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.</li>
|
||||||
</ul>
|
</ul>
|
||||||
|
|
||||||
|
<h3>Adding or changing kernels</h3>
|
||||||
|
<p>Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.</p>
|
||||||
|
<ul>
|
||||||
|
<li>Make sure custom ops are registered following PyTorch guidelines: <a href="https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial">Custom C++ and CUDA Operators</a> and <a href="https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU">The Custom Operators Manual</a></li>
|
||||||
|
<li>Custom operations that return <code>Tensors</code> require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.</li>
|
||||||
|
<li>Use <a href="https://pytorch.org/docs/stable/library.html#torch.library.opcheck"><code>torch.library.opcheck()</code></a> to test the function registration and meta-function for any registered ops. See <code>tests/kernels</code> for examples.</li>
|
||||||
|
<li>When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.</li>
|
||||||
|
<li>If a new custom type is needed, see the following document: <a href="https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA">Custom Class Support in PT2</a>.</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
<h3>Notes for Large Changes</h3>
|
<h3>Notes for Large Changes</h3>
|
||||||
<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>
|
<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>
|
||||||
|
|
||||||
|
|||||||
23
.github/workflows/add_label_ready_comment.yml
vendored
23
.github/workflows/add_label_ready_comment.yml
vendored
@@ -1,23 +0,0 @@
|
|||||||
name: Add Ready Label on Ready Comment
|
|
||||||
|
|
||||||
on:
|
|
||||||
issue_comment:
|
|
||||||
types: [created]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
add-ready-label:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '/ready')
|
|
||||||
steps:
|
|
||||||
- name: Add label
|
|
||||||
uses: actions/github-script@v5
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
github.rest.issues.addLabels({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
issue_number: context.issue.number,
|
|
||||||
labels: ['ready']
|
|
||||||
})
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
5
.github/workflows/clang-format.yml
vendored
5
.github/workflows/clang-format.yml
vendored
@@ -30,6 +30,11 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
EXCLUDES=(
|
EXCLUDES=(
|
||||||
'csrc/moe/topk_softmax_kernels.cu'
|
'csrc/moe/topk_softmax_kernels.cu'
|
||||||
|
'csrc/quantization/gguf/ggml-common.h'
|
||||||
|
'csrc/quantization/gguf/dequantize.cuh'
|
||||||
|
'csrc/quantization/gguf/vecdotq.cuh'
|
||||||
|
'csrc/quantization/gguf/mmq.cuh'
|
||||||
|
'csrc/quantization/gguf/mmvq.cuh'
|
||||||
)
|
)
|
||||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||||
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
|
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
|
||||||
|
|||||||
4
.github/workflows/mypy.yaml
vendored
4
.github/workflows/mypy.yaml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install mypy==1.9.0
|
pip install mypy==1.11.1
|
||||||
pip install types-setuptools
|
pip install types-setuptools
|
||||||
pip install types-PyYAML
|
pip install types-PyYAML
|
||||||
pip install types-requests
|
pip install types-requests
|
||||||
@@ -35,10 +35,8 @@ jobs:
|
|||||||
mypy
|
mypy
|
||||||
mypy tests --follow-imports skip
|
mypy tests --follow-imports skip
|
||||||
mypy vllm/attention --follow-imports skip
|
mypy vllm/attention --follow-imports skip
|
||||||
mypy vllm/core --follow-imports skip
|
|
||||||
mypy vllm/distributed --follow-imports skip
|
mypy vllm/distributed --follow-imports skip
|
||||||
mypy vllm/engine --follow-imports skip
|
mypy vllm/engine --follow-imports skip
|
||||||
mypy vllm/entrypoints --follow-imports skip
|
|
||||||
mypy vllm/executor --follow-imports skip
|
mypy vllm/executor --follow-imports skip
|
||||||
mypy vllm/lora --follow-imports skip
|
mypy vllm/lora --follow-imports skip
|
||||||
mypy vllm/model_executor --follow-imports skip
|
mypy vllm/model_executor --follow-imports skip
|
||||||
|
|||||||
2
.github/workflows/reminder_comment.yml
vendored
2
.github/workflows/reminder_comment.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
owner: context.repo.owner,
|
owner: context.repo.owner,
|
||||||
repo: context.repo.repo,
|
repo: context.repo.repo,
|
||||||
issue_number: context.issue.number,
|
issue_number: context.issue.number,
|
||||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
|
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
|
||||||
})
|
})
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
name: Remove ready Label on notready Comment
|
|
||||||
|
|
||||||
on:
|
|
||||||
issue_comment:
|
|
||||||
types: [created]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
add-ready-label:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready')
|
|
||||||
steps:
|
|
||||||
- name: Remove ready label
|
|
||||||
uses: actions/github-script@v5
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
github.rest.issues.removeLabel({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
issue_number: context.issue.number,
|
|
||||||
name: 'ready'
|
|
||||||
})
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -87,6 +87,9 @@ target/
|
|||||||
profile_default/
|
profile_default/
|
||||||
ipython_config.py
|
ipython_config.py
|
||||||
|
|
||||||
|
# generated files
|
||||||
|
**/generated/**
|
||||||
|
|
||||||
# pyenv
|
# pyenv
|
||||||
# For a library or package, you might want to ignore these files since the code is
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
# intended to run in multiple environments; otherwise, check them in:
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
@@ -189,4 +192,4 @@ _build/
|
|||||||
hip_compat.h
|
hip_compat.h
|
||||||
|
|
||||||
# Benchmark dataset
|
# Benchmark dataset
|
||||||
*.json
|
benchmarks/*.json
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
cmake_minimum_required(VERSION 3.21)
|
cmake_minimum_required(VERSION 3.26)
|
||||||
|
|
||||||
project(vllm_extensions LANGUAGES CXX)
|
project(vllm_extensions LANGUAGES CXX)
|
||||||
|
|
||||||
@@ -10,6 +10,9 @@ message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
|
|||||||
|
|
||||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||||
|
|
||||||
|
# Suppress potential warnings about unused manually-specified variables
|
||||||
|
set(ignoreMe "${VLLM_PYTHON_PATH}")
|
||||||
|
|
||||||
#
|
#
|
||||||
# Supported python versions. These versions will be searched in order, the
|
# Supported python versions. These versions will be searched in order, the
|
||||||
# first match will be selected. These should be kept in sync with setup.py.
|
# first match will be selected. These should be kept in sync with setup.py.
|
||||||
@@ -178,7 +181,6 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/pos_encoding_kernels.cu"
|
"csrc/pos_encoding_kernels.cu"
|
||||||
"csrc/activation_kernels.cu"
|
"csrc/activation_kernels.cu"
|
||||||
"csrc/layernorm_kernels.cu"
|
"csrc/layernorm_kernels.cu"
|
||||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
|
|
||||||
"csrc/quantization/gptq/q_gemm.cu"
|
"csrc/quantization/gptq/q_gemm.cu"
|
||||||
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
||||||
"csrc/quantization/fp8/common.cu"
|
"csrc/quantization/fp8/common.cu"
|
||||||
@@ -193,13 +195,19 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
cutlass
|
cutlass
|
||||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||||
# CUTLASS 3.5.1
|
GIT_TAG v3.5.1
|
||||||
GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
|
|
||||||
GIT_PROGRESS TRUE
|
GIT_PROGRESS TRUE
|
||||||
|
|
||||||
|
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||||
|
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
|
||||||
|
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
|
||||||
|
GIT_SHALLOW TRUE
|
||||||
)
|
)
|
||||||
FetchContent_MakeAvailable(cutlass)
|
FetchContent_MakeAvailable(cutlass)
|
||||||
|
|
||||||
list(APPEND VLLM_EXT_SRC
|
list(APPEND VLLM_EXT_SRC
|
||||||
|
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||||
|
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||||
"csrc/quantization/awq/gemm_kernels.cu"
|
"csrc/quantization/awq/gemm_kernels.cu"
|
||||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||||
@@ -208,6 +216,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
|
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
|
||||||
|
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||||
"csrc/custom_all_reduce.cu"
|
"csrc/custom_all_reduce.cu"
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||||
@@ -226,6 +235,52 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
"-gencode arch=compute_90a,code=sm_90a")
|
"-gencode arch=compute_90a,code=sm_90a")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Machete kernels
|
||||||
|
|
||||||
|
# The machete kernels only work on Hopper and require CUDA 12.0 or later.
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
||||||
|
#
|
||||||
|
# For the Machete kernels we automatically generate sources for various
|
||||||
|
# preselected input type pairs and schedules.
|
||||||
|
# Generate sources:
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${CMAKE_COMMAND} -E env
|
||||||
|
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||||
|
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
|
||||||
|
RESULT_VARIABLE machete_generation_result
|
||||||
|
OUTPUT_VARIABLE machete_generation_output
|
||||||
|
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||||
|
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||||
|
)
|
||||||
|
|
||||||
|
if (NOT machete_generation_result EQUAL 0)
|
||||||
|
message(FATAL_ERROR "Machete generation failed."
|
||||||
|
" Result: \"${machete_generation_result}\""
|
||||||
|
"\nCheck the log for details: "
|
||||||
|
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
|
||||||
|
else()
|
||||||
|
message(STATUS "Machete generation completed successfully.")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Add machete generated sources
|
||||||
|
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
|
||||||
|
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
|
||||||
|
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")
|
||||||
|
|
||||||
|
set_source_files_properties(
|
||||||
|
${MACHETE_GEN_SOURCES}
|
||||||
|
PROPERTIES
|
||||||
|
COMPILE_FLAGS
|
||||||
|
"-gencode arch=compute_90a,code=sm_90a")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Add pytorch binding for machete (added even for CUDA < 12.0 so that we can
|
||||||
|
# raise an error telling the user that this was built with an incompatible
|
||||||
|
# CUDA version)
|
||||||
|
list(APPEND VLLM_EXT_SRC
|
||||||
|
csrc/quantization/machete/machete_pytorch.cu)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
define_gpu_extension_target(
|
define_gpu_extension_target(
|
||||||
@@ -239,6 +294,12 @@ define_gpu_extension_target(
|
|||||||
USE_SABI 3
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
|
|
||||||
|
# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
|
||||||
|
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
|
||||||
|
# driver API. This causes problems when linking with earlier versions of CUDA.
|
||||||
|
# Setting this variable sidesteps the issue by calling the driver directly.
|
||||||
|
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||||
|
|
||||||
#
|
#
|
||||||
# _moe_C extension
|
# _moe_C extension
|
||||||
#
|
#
|
||||||
@@ -247,6 +308,11 @@ set(VLLM_MOE_EXT_SRC
|
|||||||
"csrc/moe/torch_bindings.cpp"
|
"csrc/moe/torch_bindings.cpp"
|
||||||
"csrc/moe/topk_softmax_kernels.cu")
|
"csrc/moe/topk_softmax_kernels.cu")
|
||||||
|
|
||||||
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
|
list(APPEND VLLM_MOE_EXT_SRC
|
||||||
|
"csrc/moe/marlin_moe_ops.cu")
|
||||||
|
endif()
|
||||||
|
|
||||||
define_gpu_extension_target(
|
define_gpu_extension_target(
|
||||||
_moe_C
|
_moe_C
|
||||||
DESTINATION vllm
|
DESTINATION vllm
|
||||||
|
|||||||
128
CODE_OF_CONDUCT.md
Normal file
128
CODE_OF_CONDUCT.md
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
|
||||||
|
# vLLM Code of Conduct
|
||||||
|
|
||||||
|
## Our Pledge
|
||||||
|
|
||||||
|
We as members, contributors, and leaders pledge to make participation in our
|
||||||
|
community a harassment-free experience for everyone, regardless of age, body
|
||||||
|
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||||
|
identity and expression, level of experience, education, socioeconomic status,
|
||||||
|
nationality, personal appearance, race, caste, color, religion, or sexual
|
||||||
|
identity and orientation.
|
||||||
|
|
||||||
|
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||||
|
diverse, inclusive, and healthy community.
|
||||||
|
|
||||||
|
## Our Standards
|
||||||
|
|
||||||
|
Examples of behavior that contributes to a positive environment for our
|
||||||
|
community include:
|
||||||
|
|
||||||
|
* Demonstrating empathy and kindness toward other people
|
||||||
|
* Being respectful of differing opinions, viewpoints, and experiences
|
||||||
|
* Giving and gracefully accepting constructive feedback
|
||||||
|
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||||
|
and learning from the experience
|
||||||
|
* Focusing on what is best not just for us as individuals, but for the overall
|
||||||
|
community
|
||||||
|
|
||||||
|
Examples of unacceptable behavior include:
|
||||||
|
|
||||||
|
* The use of sexualized language or imagery, and sexual attention or advances of
|
||||||
|
any kind
|
||||||
|
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||||
|
* Public or private harassment
|
||||||
|
* Publishing others' private information, such as a physical or email address,
|
||||||
|
without their explicit permission
|
||||||
|
* Other conduct which could reasonably be considered inappropriate in a
|
||||||
|
professional setting
|
||||||
|
|
||||||
|
## Enforcement Responsibilities
|
||||||
|
|
||||||
|
Community leaders are responsible for clarifying and enforcing our standards of
|
||||||
|
acceptable behavior and will take appropriate and fair corrective action in
|
||||||
|
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||||
|
or harmful.
|
||||||
|
|
||||||
|
Community leaders have the right and responsibility to remove, edit, or reject
|
||||||
|
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||||
|
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||||
|
decisions when appropriate.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
This Code of Conduct applies within all community spaces, and also applies when
|
||||||
|
an individual is officially representing the community in public spaces.
|
||||||
|
Examples of representing our community include using an official email address,
|
||||||
|
posting via an official social media account, or acting as an appointed
|
||||||
|
representative at an online or offline/IRL event.
|
||||||
|
|
||||||
|
## Enforcement
|
||||||
|
|
||||||
|
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||||
|
reported to the community leaders responsible for enforcement in the #code-of-conduct
|
||||||
|
channel in the [vLLM Discord](https://discord.com/invite/jz7wjKhh6g).
|
||||||
|
All complaints will be reviewed and investigated promptly and fairly.
|
||||||
|
|
||||||
|
All community leaders are obligated to respect the privacy and security of the
|
||||||
|
reporter of any incident.
|
||||||
|
|
||||||
|
## Enforcement Guidelines
|
||||||
|
|
||||||
|
Community leaders will follow these Community Impact Guidelines in determining
|
||||||
|
the consequences for any action they deem in violation of this Code of Conduct:
|
||||||
|
|
||||||
|
### 1. Correction
|
||||||
|
|
||||||
|
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||||
|
unprofessional or unwelcome in the community.
|
||||||
|
|
||||||
|
**Consequence**: A private, written warning from community leaders, providing
|
||||||
|
clarity around the nature of the violation and an explanation of why the
|
||||||
|
behavior was inappropriate. A public apology may be requested.
|
||||||
|
|
||||||
|
### 2. Warning
|
||||||
|
|
||||||
|
**Community Impact**: A violation through a single incident or series of
|
||||||
|
actions.
|
||||||
|
|
||||||
|
**Consequence**: A warning with consequences for continued behavior. No
|
||||||
|
interaction with the people involved, including unsolicited interaction with
|
||||||
|
those enforcing the Code of Conduct, for a specified period of time. This
|
||||||
|
includes avoiding interactions in community spaces as well as external channels
|
||||||
|
like social media. Violating these terms may lead to a temporary or permanent
|
||||||
|
ban.
|
||||||
|
|
||||||
|
### 3. Temporary Ban
|
||||||
|
|
||||||
|
**Community Impact**: A serious violation of community standards, including
|
||||||
|
sustained inappropriate behavior.
|
||||||
|
|
||||||
|
**Consequence**: A temporary ban from any sort of interaction or public
|
||||||
|
communication with the community for a specified period of time. No public or
|
||||||
|
private interaction with the people involved, including unsolicited interaction
|
||||||
|
with those enforcing the Code of Conduct, is allowed during this period.
|
||||||
|
Violating these terms may lead to a permanent ban.
|
||||||
|
|
||||||
|
### 4. Permanent Ban
|
||||||
|
|
||||||
|
**Community Impact**: Demonstrating a pattern of violation of community
|
||||||
|
standards, including sustained inappropriate behavior, harassment of an
|
||||||
|
individual, or aggression toward or disparagement of classes of individuals.
|
||||||
|
|
||||||
|
**Consequence**: A permanent ban from any sort of public interaction within the
|
||||||
|
community.
|
||||||
|
|
||||||
|
## Attribution
|
||||||
|
|
||||||
|
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/),
|
||||||
|
version 2.1, available at
|
||||||
|
[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html).
|
||||||
|
|
||||||
|
Community Impact Guidelines were inspired by
|
||||||
|
[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion).
|
||||||
|
|
||||||
|
For answers to common questions about this code of conduct, see the
|
||||||
|
[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at
|
||||||
|
[Contributor Covenant translations](https://www.contributor-covenant.org/translations).
|
||||||
|
|
||||||
108
Dockerfile
108
Dockerfile
@@ -9,28 +9,23 @@ ARG CUDA_VERSION=12.4.1
|
|||||||
#################### BASE BUILD IMAGE ####################
|
#################### BASE BUILD IMAGE ####################
|
||||||
# prepare basic build environment
|
# prepare basic build environment
|
||||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||||
|
|
||||||
ARG CUDA_VERSION=12.4.1
|
ARG CUDA_VERSION=12.4.1
|
||||||
ARG PYTHON_VERSION=3.10
|
ARG PYTHON_VERSION=3.12
|
||||||
|
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# Install Python and other dependencies
|
||||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||||
&& apt-get update -y \
|
&& apt-get update -y \
|
||||||
&& apt-get install -y ccache software-properties-common \
|
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||||
&& apt-get update -y \
|
&& apt-get update -y \
|
||||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||||
&& if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
|
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||||
&& python3 --version
|
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||||
|
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||||
RUN apt-get update -y \
|
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||||
&& apt-get install -y git curl sudo
|
&& python3 --version && python3 -m pip --version
|
||||||
|
|
||||||
# Install pip s.t. it will be compatible with our PYTHON_VERSION
|
|
||||||
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
|
|
||||||
RUN python3 -m pip --version
|
|
||||||
|
|
||||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||||
@@ -42,14 +37,10 @@ WORKDIR /workspace
|
|||||||
|
|
||||||
# install build and runtime dependencies
|
# install build and runtime dependencies
|
||||||
COPY requirements-common.txt requirements-common.txt
|
COPY requirements-common.txt requirements-common.txt
|
||||||
COPY requirements-adag.txt requirements-adag.txt
|
|
||||||
COPY requirements-cuda.txt requirements-cuda.txt
|
COPY requirements-cuda.txt requirements-cuda.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 -m pip install -r requirements-cuda.txt
|
python3 -m pip install -r requirements-cuda.txt
|
||||||
|
|
||||||
COPY requirements-mamba.txt requirements-mamba.txt
|
|
||||||
RUN python3 -m pip install packaging
|
|
||||||
RUN python3 -m pip install -r requirements-mamba.txt
|
|
||||||
|
|
||||||
# cuda arch list used by torch
|
# cuda arch list used by torch
|
||||||
# can be useful for both `dev` and `test`
|
# can be useful for both `dev` and `test`
|
||||||
@@ -62,24 +53,18 @@ ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
|||||||
#################### WHEEL BUILD IMAGE ####################
|
#################### WHEEL BUILD IMAGE ####################
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
|
|
||||||
ARG PYTHON_VERSION=3.10
|
|
||||||
|
|
||||||
# install build dependencies
|
# install build dependencies
|
||||||
COPY requirements-build.txt requirements-build.txt
|
COPY requirements-build.txt requirements-build.txt
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 -m pip install -r requirements-build.txt
|
python3 -m pip install -r requirements-build.txt
|
||||||
|
|
||||||
# install compiler cache to speed up compilation leveraging local or remote caching
|
|
||||||
RUN apt-get update -y && apt-get install -y ccache
|
|
||||||
|
|
||||||
# files and directories related to build wheels
|
# files and directories related to build wheels
|
||||||
COPY csrc csrc
|
COPY csrc csrc
|
||||||
COPY setup.py setup.py
|
COPY setup.py setup.py
|
||||||
COPY cmake cmake
|
COPY cmake cmake
|
||||||
COPY CMakeLists.txt CMakeLists.txt
|
COPY CMakeLists.txt CMakeLists.txt
|
||||||
COPY requirements-common.txt requirements-common.txt
|
COPY requirements-common.txt requirements-common.txt
|
||||||
COPY requirements-adag.txt requirements-adag.txt
|
|
||||||
COPY requirements-cuda.txt requirements-cuda.txt
|
COPY requirements-cuda.txt requirements-cuda.txt
|
||||||
COPY pyproject.toml pyproject.toml
|
COPY pyproject.toml pyproject.toml
|
||||||
COPY vllm vllm
|
COPY vllm vllm
|
||||||
@@ -95,6 +80,8 @@ ARG buildkite_commit
|
|||||||
ENV BUILDKITE_COMMIT=${buildkite_commit}
|
ENV BUILDKITE_COMMIT=${buildkite_commit}
|
||||||
|
|
||||||
ARG USE_SCCACHE
|
ARG USE_SCCACHE
|
||||||
|
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||||
|
ARG SCCACHE_REGION_NAME=us-west-2
|
||||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||||
@@ -103,12 +90,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
&& tar -xzf sccache.tar.gz \
|
&& tar -xzf sccache.tar.gz \
|
||||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||||
&& if [ "$CUDA_VERSION" = "11.8.0" ]; then \
|
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
|
||||||
export SCCACHE_BUCKET=vllm-build-sccache-2; \
|
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
|
||||||
else \
|
&& export SCCACHE_IDLE_TIMEOUT=0 \
|
||||||
export SCCACHE_BUCKET=vllm-build-sccache; \
|
|
||||||
fi \
|
|
||||||
&& export SCCACHE_REGION=us-west-2 \
|
|
||||||
&& export CMAKE_BUILD_TYPE=Release \
|
&& export CMAKE_BUILD_TYPE=Release \
|
||||||
&& sccache --show-stats \
|
&& sccache --show-stats \
|
||||||
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
|
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
|
||||||
@@ -122,10 +106,17 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
|||||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# check the size of the wheel, we cannot upload wheels larger than 100MB
|
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||||
RUN python3 check-wheel-size.py dist
|
# Default max size of the wheel is 250MB
|
||||||
|
ARG VLLM_MAX_SIZE_MB=250
|
||||||
|
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
|
||||||
|
ARG RUN_WHEEL_CHECK=true
|
||||||
|
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||||
|
python3 check-wheel-size.py dist; \
|
||||||
|
else \
|
||||||
|
echo "Skipping wheel size check."; \
|
||||||
|
fi
|
||||||
#################### EXTENSION Build IMAGE ####################
|
#################### EXTENSION Build IMAGE ####################
|
||||||
|
|
||||||
#################### DEV IMAGE ####################
|
#################### DEV IMAGE ####################
|
||||||
@@ -138,45 +129,31 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
python3 -m pip install -r requirements-dev.txt
|
python3 -m pip install -r requirements-dev.txt
|
||||||
|
|
||||||
#################### DEV IMAGE ####################
|
#################### DEV IMAGE ####################
|
||||||
#################### MAMBA Build IMAGE ####################
|
|
||||||
FROM dev as mamba-builder
|
|
||||||
# max jobs used for build
|
|
||||||
ARG max_jobs=2
|
|
||||||
ENV MAX_JOBS=${max_jobs}
|
|
||||||
|
|
||||||
WORKDIR /usr/src/mamba
|
|
||||||
|
|
||||||
COPY requirements-mamba.txt requirements-mamba.txt
|
|
||||||
|
|
||||||
# Download the wheel or build it if a pre-compiled release doesn't exist
|
|
||||||
RUN pip --verbose wheel -r requirements-mamba.txt \
|
|
||||||
--no-build-isolation --no-deps --no-cache-dir
|
|
||||||
|
|
||||||
#################### MAMBA Build IMAGE ####################
|
|
||||||
|
|
||||||
#################### vLLM installation IMAGE ####################
|
#################### vLLM installation IMAGE ####################
|
||||||
# image with vLLM installed
|
# image with vLLM installed
|
||||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
|
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
|
||||||
ARG CUDA_VERSION=12.4.1
|
ARG CUDA_VERSION=12.4.1
|
||||||
ARG PYTHON_VERSION=3.10
|
ARG PYTHON_VERSION=3.12
|
||||||
WORKDIR /vllm-workspace
|
WORKDIR /vllm-workspace
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
|
||||||
|
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
|
||||||
|
|
||||||
|
# Install Python and other dependencies
|
||||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||||
&& apt-get update -y \
|
&& apt-get update -y \
|
||||||
&& apt-get install -y ccache software-properties-common \
|
&& apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
|
||||||
|
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
|
||||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||||
&& apt-get update -y \
|
&& apt-get update -y \
|
||||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
|
||||||
&& if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
|
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||||
&& python3 --version
|
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||||
|
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||||
RUN apt-get update -y \
|
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||||
&& apt-get install -y python3-pip git vim curl libibverbs-dev
|
&& python3 --version && python3 -m pip --version
|
||||||
|
|
||||||
# Install pip s.t. it will be compatible with our PYTHON_VERSION
|
|
||||||
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
|
|
||||||
RUN python3 -m pip --version
|
|
||||||
|
|
||||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||||
@@ -189,12 +166,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
|||||||
--mount=type=cache,target=/root/.cache/pip \
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 -m pip install dist/*.whl --verbose
|
python3 -m pip install dist/*.whl --verbose
|
||||||
|
|
||||||
RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
|
|
||||||
--mount=type=cache,target=/root/.cache/pip \
|
|
||||||
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
|
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
. /etc/environment && \
|
||||||
|
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
|
||||||
#################### vLLM installation IMAGE ####################
|
#################### vLLM installation IMAGE ####################
|
||||||
|
|
||||||
|
|
||||||
@@ -206,6 +180,10 @@ FROM vllm-base AS test
|
|||||||
ADD . /vllm-workspace/
|
ADD . /vllm-workspace/
|
||||||
|
|
||||||
# install development dependencies (for testing)
|
# install development dependencies (for testing)
|
||||||
|
# A newer setuptools is required for installing some test dependencies from source that do not publish python 3.12 wheels
|
||||||
|
# This installation must complete before the test dependencies are collected and installed.
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
python3 -m pip install "setuptools>=74.1.1"
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 -m pip install -r requirements-dev.txt
|
python3 -m pip install -r requirements-dev.txt
|
||||||
|
|
||||||
|
|||||||
@@ -2,37 +2,66 @@
|
|||||||
|
|
||||||
FROM ubuntu:22.04 AS cpu-test-1
|
FROM ubuntu:22.04 AS cpu-test-1
|
||||||
|
|
||||||
RUN apt-get update -y \
|
ENV CCACHE_DIR=/root/.cache/ccache
|
||||||
&& apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
|
|
||||||
|
ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||||
|
|
||||||
|
RUN --mount=type=cache,target=/var/cache/apt \
|
||||||
|
apt-get update -y \
|
||||||
|
&& apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
|
||||||
|
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
|
||||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||||
|
|
||||||
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
|
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
|
||||||
# intel-openmp provides additional performance improvement vs. openmp
|
# intel-openmp provides additional performance improvement vs. openmp
|
||||||
# tcmalloc provides better memory allocation efficiency, e.g., holding memory in caches to speed up access of commonly-used objects.
|
# tcmalloc provides better memory allocation efficiency, e.g., holding memory in caches to speed up access of commonly-used objects.
|
||||||
RUN pip install intel-openmp
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
pip install intel-openmp
|
||||||
|
|
||||||
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
|
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so"
|
||||||
|
|
||||||
RUN echo 'ulimit -c 0' >> ~/.bashrc
|
RUN echo 'ulimit -c 0' >> ~/.bashrc
|
||||||
|
|
||||||
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
|
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
RUN pip install --upgrade pip \
|
ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu
|
||||||
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
|
||||||
|
pip install --upgrade pip && \
|
||||||
|
pip install -r requirements-build.txt
|
||||||
|
|
||||||
|
# install oneDNN
|
||||||
|
RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git
|
||||||
|
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||||
|
cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \
|
||||||
|
-DONEDNN_BUILD_DOC=OFF \
|
||||||
|
-DONEDNN_BUILD_EXAMPLES=OFF \
|
||||||
|
-DONEDNN_BUILD_TESTS=OFF \
|
||||||
|
-DONEDNN_BUILD_GRAPH=OFF \
|
||||||
|
-DONEDNN_ENABLE_WORKLOAD=INFERENCE \
|
||||||
|
-DONEDNN_ENABLE_PRIMITIVE=MATMUL && \
|
||||||
|
cmake --build ./oneDNN/build --target install --config Release
|
||||||
|
|
||||||
FROM cpu-test-1 AS build
|
FROM cpu-test-1 AS build
|
||||||
|
|
||||||
COPY ./ /workspace/vllm
|
|
||||||
|
|
||||||
WORKDIR /workspace/vllm
|
WORKDIR /workspace/vllm
|
||||||
|
|
||||||
RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
|
||||||
|
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
|
||||||
|
pip install -v -r requirements-cpu.txt
|
||||||
|
|
||||||
|
COPY ./ ./
|
||||||
|
|
||||||
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
|
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
|
||||||
ARG VLLM_CPU_DISABLE_AVX512
|
ARG VLLM_CPU_DISABLE_AVX512
|
||||||
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
|
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
|
||||||
|
|
||||||
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
--mount=type=cache,target=/root/.cache/ccache \
|
||||||
|
VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
|
||||||
|
pip install dist/*.whl
|
||||||
|
|
||||||
WORKDIR /workspace/
|
WORKDIR /workspace/
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
# default base image
|
# default base image
|
||||||
ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
|
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.19.1-ubuntu20.04"
|
||||||
|
|
||||||
FROM $BASE_IMAGE
|
FROM $BASE_IMAGE
|
||||||
|
|
||||||
RUN echo "Base image is $BASE_IMAGE"
|
RUN echo "Base image is $BASE_IMAGE"
|
||||||
|
|
||||||
# Install some basic utilities
|
# Install some basic utilities
|
||||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
RUN apt-get update \
|
||||||
|
&& apt-get install python3 python3-pip -y \
|
||||||
|
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1
|
||||||
|
|
||||||
### Mount Point ###
|
### Mount Point ###
|
||||||
# When launching the container, mount the code directory to /app
|
# When launching the container, mount the code directory to /app
|
||||||
|
|||||||
@@ -4,7 +4,8 @@
|
|||||||
FROM ubuntu:22.04 AS dev
|
FROM ubuntu:22.04 AS dev
|
||||||
|
|
||||||
RUN apt-get update -y && \
|
RUN apt-get update -y && \
|
||||||
apt-get install -y python3-pip git
|
apt-get install -y python3-pip git && \
|
||||||
|
apt-get install -y ffmpeg libsm6 libxext6 libgl1
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
# copy requirements
|
# copy requirements
|
||||||
@@ -21,7 +22,7 @@ COPY setup.py /workspace/vllm/
|
|||||||
# install build requirements
|
# install build requirements
|
||||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
|
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
|
||||||
# build vLLM with OpenVINO backend
|
# build vLLM with OpenVINO backend
|
||||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
|
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
|
||||||
|
|
||||||
COPY examples/ /workspace/vllm/examples
|
COPY examples/ /workspace/vllm/examples
|
||||||
COPY benchmarks/ /workspace/vllm/benchmarks
|
COPY benchmarks/ /workspace/vllm/benchmarks
|
||||||
|
|||||||
@@ -2,21 +2,26 @@ FROM mambaorg/micromamba
|
|||||||
ARG MAMBA_DOCKERFILE_ACTIVATE=1
|
ARG MAMBA_DOCKERFILE_ACTIVATE=1
|
||||||
USER root
|
USER root
|
||||||
|
|
||||||
RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
||||||
|
|
||||||
|
RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1
|
||||||
|
|
||||||
# Some packages in requirements-cpu are installed here
|
# Some packages in requirements-cpu are installed here
|
||||||
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
||||||
# Currently these may not be available for venv or pip directly
|
# Currently these may not be available for venv or pip directly
|
||||||
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes
|
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes
|
||||||
|
|
||||||
COPY ./ /workspace/vllm
|
COPY ./ /workspace/vllm
|
||||||
|
|
||||||
WORKDIR /workspace/vllm
|
WORKDIR /workspace/vllm
|
||||||
|
|
||||||
# These packages will be in rocketce eventually
|
# These packages will be in rocketce eventually
|
||||||
RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing
|
RUN pip install -v cmake xformers torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing
|
||||||
|
|
||||||
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
||||||
|
|
||||||
WORKDIR /vllm-workspace
|
WORKDIR /workspace/
|
||||||
ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]
|
|
||||||
|
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||||
|
|
||||||
|
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||||
|
|||||||
@@ -1,23 +1,20 @@
|
|||||||
ARG NIGHTLY_DATE="20240726"
|
ARG NIGHTLY_DATE="20240828"
|
||||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||||
|
|
||||||
FROM $BASE_IMAGE
|
FROM $BASE_IMAGE
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
# Install aiohttp separately to avoid build errors.
|
# Install some basic utilities
|
||||||
RUN pip install aiohttp
|
RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 libgl1
|
||||||
# Install NumPy 1 instead of NumPy 2.
|
|
||||||
RUN pip install "numpy<2"
|
|
||||||
# Install the TPU and Pallas dependencies.
|
|
||||||
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
|
||||||
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
|
||||||
|
|
||||||
# Fix FastAPI dependence
|
# Install the TPU and Pallas dependencies.
|
||||||
RUN pip install "starlette<0.38.0"
|
RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
|
RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
|
|
||||||
# Build vLLM.
|
# Build vLLM.
|
||||||
COPY . /workspace/vllm
|
COPY . /workspace/vllm
|
||||||
ENV VLLM_TARGET_DEVICE="tpu"
|
ENV VLLM_TARGET_DEVICE="tpu"
|
||||||
RUN cd /workspace/vllm && python setup.py develop
|
RUN cd /workspace/vllm && python3 -m pip install -r requirements-tpu.txt
|
||||||
|
RUN cd /workspace/vllm && python3 setup.py develop
|
||||||
|
|
||||||
CMD ["/bin/bash"]
|
CMD ["/bin/bash"]
|
||||||
|
|||||||
@@ -9,8 +9,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
|
|||||||
chmod 644 /usr/share/keyrings/intel-graphics.gpg
|
chmod 644 /usr/share/keyrings/intel-graphics.gpg
|
||||||
|
|
||||||
RUN apt-get update -y \
|
RUN apt-get update -y \
|
||||||
&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip
|
&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1
|
||||||
|
|
||||||
COPY ./ /workspace/vllm
|
COPY ./ /workspace/vllm
|
||||||
|
|
||||||
WORKDIR /workspace/vllm
|
WORKDIR /workspace/vllm
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
include LICENSE
|
include LICENSE
|
||||||
include requirements-adag.txt
|
|
||||||
include requirements-common.txt
|
include requirements-common.txt
|
||||||
include requirements-cuda.txt
|
include requirements-cuda.txt
|
||||||
include requirements-rocm.txt
|
include requirements-rocm.txt
|
||||||
|
|||||||
37
README.md
37
README.md
@@ -10,13 +10,23 @@ Easy, fast, and cheap LLM serving for everyone
|
|||||||
</h3>
|
</h3>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> |
|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**vLLM, AMD, Anyscale Meet & Greet at [Ray Summit 2024](http://raysummit.anyscale.com) (Monday, Sept 30th, 5-7pm PT) at Marriott Marquis San Francisco**
|
||||||
|
|
||||||
|
We are excited to announce our special vLLM event in collaboration with AMD and Anyscale.
|
||||||
|
Join us to learn more about recent advancements of vLLM on MI300X.
|
||||||
|
Register [here](https://lu.ma/db5ld9n5) and be a part of the event!
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
*Latest News* 🔥
|
*Latest News* 🔥
|
||||||
|
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
|
||||||
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
|
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
|
||||||
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
|
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
|
||||||
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
|
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
|
||||||
@@ -36,10 +46,12 @@ vLLM is fast with:
|
|||||||
- Efficient management of attention key and value memory with **PagedAttention**
|
- Efficient management of attention key and value memory with **PagedAttention**
|
||||||
- Continuous batching of incoming requests
|
- Continuous batching of incoming requests
|
||||||
- Fast model execution with CUDA/HIP graph
|
- Fast model execution with CUDA/HIP graph
|
||||||
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
|
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
|
||||||
- Optimized CUDA kernels
|
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
|
||||||
|
- Speculative decoding
|
||||||
|
- Chunked prefill
|
||||||
|
|
||||||
**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).
|
**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).
|
||||||
|
|
||||||
vLLM is flexible and easy to use with:
|
vLLM is flexible and easy to use with:
|
||||||
|
|
||||||
@@ -48,20 +60,21 @@ vLLM is flexible and easy to use with:
|
|||||||
- Tensor parallelism and pipeline parallelism support for distributed inference
|
- Tensor parallelism and pipeline parallelism support for distributed inference
|
||||||
- Streaming outputs
|
- Streaming outputs
|
||||||
- OpenAI-compatible API server
|
- OpenAI-compatible API server
|
||||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs
|
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
|
||||||
- (Experimental) Prefix caching support
|
- Prefix caching support
|
||||||
- (Experimental) Multi-lora support
|
- Multi-lora support
|
||||||
|
|
||||||
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||||
- Transformer-like LLMs (e.g., Llama)
|
- Transformer-like LLMs (e.g., Llama)
|
||||||
- Mixture-of-Expert LLMs (e.g., Mixtral)
|
- Mixture-of-Expert LLMs (e.g., Mixtral)
|
||||||
|
- Embedding Models (e.g. E5-Mistral)
|
||||||
- Multi-modal LLMs (e.g., LLaVA)
|
- Multi-modal LLMs (e.g., LLaVA)
|
||||||
|
|
||||||
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
Install vLLM with `pip` or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install vllm
|
pip install vllm
|
||||||
@@ -99,6 +112,7 @@ vLLM is a community project. Our compute resources for development and testing a
|
|||||||
- Roblox
|
- Roblox
|
||||||
- RunPod
|
- RunPod
|
||||||
- Sequoia Capital
|
- Sequoia Capital
|
||||||
|
- Skywork AI
|
||||||
- Trainy
|
- Trainy
|
||||||
- UC Berkeley
|
- UC Berkeley
|
||||||
- UC San Diego
|
- UC San Diego
|
||||||
@@ -117,3 +131,10 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
|
|||||||
year={2023}
|
year={2023}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Contact Us
|
||||||
|
|
||||||
|
* For technical questions and feature requests, please use Github issues or discussions.
|
||||||
|
* For discussing with fellow users, please use Discord.
|
||||||
|
* For security disclosures, please use Github's security advisory feature.
|
||||||
|
* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu.
|
||||||
@@ -24,6 +24,7 @@ class RequestFuncInput:
|
|||||||
model: str
|
model: str
|
||||||
best_of: int = 1
|
best_of: int = 1
|
||||||
use_beam_search: bool = False
|
use_beam_search: bool = False
|
||||||
|
logprobs: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -225,8 +226,8 @@ async def async_request_openai_completions(
|
|||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith(
|
assert api_url.endswith(
|
||||||
"completions"
|
("completions", "profile")
|
||||||
), "OpenAI Completions API URL must end with 'completions'."
|
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
assert not request_func_input.use_beam_search
|
assert not request_func_input.use_beam_search
|
||||||
@@ -236,6 +237,7 @@ async def async_request_openai_completions(
|
|||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"best_of": request_func_input.best_of,
|
"best_of": request_func_input.best_of,
|
||||||
"max_tokens": request_func_input.output_len,
|
"max_tokens": request_func_input.output_len,
|
||||||
|
"logprobs": request_func_input.logprobs,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
headers = {
|
headers = {
|
||||||
@@ -276,8 +278,9 @@ async def async_request_openai_completions(
|
|||||||
output.ttft = ttft
|
output.ttft = ttft
|
||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
output.itl.append(timestamp -
|
else:
|
||||||
most_recent_timestamp)
|
output.itl.append(timestamp -
|
||||||
|
most_recent_timestamp)
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
generated_text += data["choices"][0]["text"]
|
generated_text += data["choices"][0]["text"]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
|
||||||
from vllm.inputs import PromptInputs
|
from vllm.inputs import PromptInputs
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
@@ -205,13 +205,11 @@ if __name__ == '__main__':
|
|||||||
default=None,
|
default=None,
|
||||||
help=('path to save the pytorch profiler output. Can be visualized '
|
help=('path to save the pytorch profiler output. Can be visualized '
|
||||||
'with ui.perfetto.dev or Tensorboard.'))
|
'with ui.perfetto.dev or Tensorboard.'))
|
||||||
parser.add_argument(
|
parser.add_argument("--device",
|
||||||
"--device",
|
type=str,
|
||||||
type=str,
|
default="auto",
|
||||||
default="auto",
|
choices=DEVICE_OPTIONS,
|
||||||
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
|
help='device type for vLLM execution')
|
||||||
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
|
|
||||||
'CPU.')
|
|
||||||
parser.add_argument('--block-size',
|
parser.add_argument('--block-size',
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
|
|||||||
@@ -1,8 +1,45 @@
|
|||||||
|
"""
|
||||||
|
Benchmark the efficiency of prefix caching.
|
||||||
|
|
||||||
|
This script allows you to benchmark the performance of
|
||||||
|
a model with and without prefix caching using either fixed prompts
|
||||||
|
or prompts sampled from the ShareGPT dataset.
|
||||||
|
|
||||||
|
Fixed example usage:
|
||||||
|
python benchmark_prefix_caching.py \
|
||||||
|
--model meta-llama/Llama-2-7b-chat-hf \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--num-prompts 1 \
|
||||||
|
--repeat-count 100
|
||||||
|
|
||||||
|
ShareGPT example usage:
|
||||||
|
# This command samples 20 prompts with input lengths
|
||||||
|
# between 128 and 256 tokens from the ShareGPT dataset,
|
||||||
|
# then replicates each prompt 5 times.
|
||||||
|
python benchmark_prefix_caching.py \
|
||||||
|
--model meta-llama/Llama-2-7b-chat-hf \
|
||||||
|
--dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--num-prompts 20 \
|
||||||
|
--repeat-count 5 \
|
||||||
|
--input-length-range 128:256
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
except ImportError:
|
||||||
|
from backend_request_func import get_tokenizer
|
||||||
|
|
||||||
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
|
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
@@ -15,7 +52,83 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
|
|||||||
print(f"cost time {end_time - start_time}")
|
print(f"cost time {end_time - start_time}")
|
||||||
|
|
||||||
|
|
||||||
|
def sample_requests(
|
||||||
|
dataset_path: str,
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
input_length_range: Tuple[int, int],
|
||||||
|
fixed_output_len: Optional[int],
|
||||||
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
|
# Load the dataset.
|
||||||
|
with open(dataset_path) as f:
|
||||||
|
dataset = json.load(f)
|
||||||
|
# Filter out the conversations with less than 2 turns.
|
||||||
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
|
# Only keep the first two turns of each conversation.
|
||||||
|
dataset = [(data["conversations"][0]["value"],
|
||||||
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
|
|
||||||
|
# Shuffle the dataset.
|
||||||
|
random.shuffle(dataset)
|
||||||
|
|
||||||
|
min_len, max_len = input_length_range
|
||||||
|
|
||||||
|
# Filter out sequences that are too long or too short
|
||||||
|
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
if len(filtered_dataset) == num_requests:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Tokenize the prompts and completions.
|
||||||
|
prompt = dataset[i][0]
|
||||||
|
prompt_token_ids = tokenizer(prompt).input_ids
|
||||||
|
completion = dataset[i][1]
|
||||||
|
completion_token_ids = tokenizer(completion).input_ids
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
output_len = len(completion_token_ids
|
||||||
|
) if fixed_output_len is None else fixed_output_len
|
||||||
|
if prompt_len < 4 or output_len < 4:
|
||||||
|
# Prune too short sequences.
|
||||||
|
continue
|
||||||
|
if min_len <= prompt_len <= max_len:
|
||||||
|
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||||
|
|
||||||
|
return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
|
||||||
|
repeat_count: int,
|
||||||
|
sort: bool = False) -> List[str]:
|
||||||
|
repeated_requests = requests * repeat_count
|
||||||
|
if sort:
|
||||||
|
repeated_requests.sort(key=lambda x: x[1])
|
||||||
|
else:
|
||||||
|
random.shuffle(repeated_requests)
|
||||||
|
return [req[0] for req in repeated_requests]
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
|
||||||
|
input_length_range = tuple(map(int, args.input_length_range.split(':')))
|
||||||
|
|
||||||
|
if args.dataset_path is not None:
|
||||||
|
print(f"Start to sample {args.num_prompts} prompts"
|
||||||
|
"from {args.dataset_path}")
|
||||||
|
filtered_datasets = sample_requests(
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
input_length_range=input_length_range,
|
||||||
|
fixed_output_len=args.output_len,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt_len = len(tokenizer(PROMPT).input_ids)
|
||||||
|
filtered_datasets = [(PROMPT, prompt_len, args.output_len)
|
||||||
|
] * args.num_prompts
|
||||||
|
|
||||||
llm = LLM(model=args.model,
|
llm = LLM(model=args.model,
|
||||||
tokenizer_mode='auto',
|
tokenizer_mode='auto',
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
@@ -24,10 +137,13 @@ def main(args):
|
|||||||
tensor_parallel_size=args.tensor_parallel_size,
|
tensor_parallel_size=args.tensor_parallel_size,
|
||||||
enable_prefix_caching=args.enable_prefix_caching)
|
enable_prefix_caching=args.enable_prefix_caching)
|
||||||
|
|
||||||
num_prompts = 100
|
|
||||||
prompts = [PROMPT] * num_prompts
|
|
||||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||||
|
|
||||||
|
print("Testing filtered datasets")
|
||||||
|
prompts = repeat_and_sort_requests(filtered_datasets,
|
||||||
|
repeat_count=args.repeat_count,
|
||||||
|
sort=args.sort)
|
||||||
|
|
||||||
print("------warm up------")
|
print("------warm up------")
|
||||||
test_prefix(
|
test_prefix(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
@@ -45,11 +161,15 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description='Benchmark the performance with or without automatic '
|
description=
|
||||||
'prefix caching.')
|
'Benchmark the performance with or without automatic prefix caching.')
|
||||||
parser.add_argument('--model',
|
parser.add_argument('--model',
|
||||||
type=str,
|
type=str,
|
||||||
default='baichuan-inc/Baichuan2-13B-Chat')
|
default='baichuan-inc/Baichuan2-13B-Chat')
|
||||||
|
parser.add_argument("--dataset-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the dataset.")
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||||
parser.add_argument('--output-len', type=int, default=10)
|
parser.add_argument('--output-len', type=int, default=10)
|
||||||
parser.add_argument('--enable-prefix-caching',
|
parser.add_argument('--enable-prefix-caching',
|
||||||
@@ -58,5 +178,21 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('--use-v2-block-manager',
|
parser.add_argument('--use-v2-block-manager',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Use BlockSpaceMangerV2')
|
help='Use BlockSpaceMangerV2')
|
||||||
|
parser.add_argument('--num-prompts',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of the prompts sampled from dataset")
|
||||||
|
parser.add_argument('--repeat-count',
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help='Number of times to repeat each prompt')
|
||||||
|
parser.add_argument('--sort',
|
||||||
|
action='store_true',
|
||||||
|
help='Sort prompts by input length')
|
||||||
|
parser.add_argument('--input-length-range',
|
||||||
|
type=str,
|
||||||
|
default='128:256',
|
||||||
|
help='Range of input lengths for sampling prompts,'
|
||||||
|
'specified as "min:max" (e.g., "128:256").')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -56,20 +56,27 @@ class BenchmarkMetrics:
|
|||||||
total_input: int
|
total_input: int
|
||||||
total_output: int
|
total_output: int
|
||||||
request_throughput: float
|
request_throughput: float
|
||||||
input_throughput: float
|
|
||||||
output_throughput: float
|
output_throughput: float
|
||||||
|
total_token_throughput: float
|
||||||
mean_ttft_ms: float
|
mean_ttft_ms: float
|
||||||
median_ttft_ms: float
|
median_ttft_ms: float
|
||||||
std_ttft_ms: float
|
std_ttft_ms: float
|
||||||
p99_ttft_ms: float
|
percentiles_ttft_ms: List[Tuple[float, float]]
|
||||||
mean_tpot_ms: float
|
mean_tpot_ms: float
|
||||||
median_tpot_ms: float
|
median_tpot_ms: float
|
||||||
std_tpot_ms: float
|
std_tpot_ms: float
|
||||||
p99_tpot_ms: float
|
percentiles_tpot_ms: List[Tuple[float, float]]
|
||||||
mean_itl_ms: float
|
mean_itl_ms: float
|
||||||
median_itl_ms: float
|
median_itl_ms: float
|
||||||
std_itl_ms: float
|
std_itl_ms: float
|
||||||
p99_itl_ms: float
|
percentiles_itl_ms: List[Tuple[float, float]]
|
||||||
|
# E2EL stands for end-to-end latency per request.
|
||||||
|
# It is the time taken on the client side from sending
|
||||||
|
# a request to receiving a complete response.
|
||||||
|
mean_e2el_ms: float
|
||||||
|
median_e2el_ms: float
|
||||||
|
std_e2el_ms: float
|
||||||
|
percentiles_e2el_ms: List[Tuple[float, float]]
|
||||||
|
|
||||||
|
|
||||||
def sample_sharegpt_requests(
|
def sample_sharegpt_requests(
|
||||||
@@ -188,8 +195,16 @@ def sample_sonnet_requests(
|
|||||||
|
|
||||||
|
|
||||||
def sample_random_requests(
|
def sample_random_requests(
|
||||||
input_len: int, output_len: int, num_prompts: int, range_ratio: float,
|
prefix_len: int,
|
||||||
tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
|
input_len: int,
|
||||||
|
output_len: int,
|
||||||
|
num_prompts: int,
|
||||||
|
range_ratio: float,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
prefix_token_ids = np.random.randint(0,
|
||||||
|
tokenizer.vocab_size,
|
||||||
|
size=prefix_len).tolist()
|
||||||
|
|
||||||
input_lens = np.random.randint(
|
input_lens = np.random.randint(
|
||||||
int(input_len * range_ratio),
|
int(input_len * range_ratio),
|
||||||
@@ -204,10 +219,12 @@ def sample_random_requests(
|
|||||||
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
||||||
input_requests = []
|
input_requests = []
|
||||||
for i in range(num_prompts):
|
for i in range(num_prompts):
|
||||||
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
|
prompt = tokenizer.decode(prefix_token_ids +
|
||||||
|
[(offsets[i] + i + j) % tokenizer.vocab_size
|
||||||
for j in range(input_lens[i])])
|
for j in range(input_lens[i])])
|
||||||
|
|
||||||
input_requests.append(
|
input_requests.append(
|
||||||
(prompt, int(input_lens[i]), int(output_lens[i])))
|
(prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))
|
||||||
|
|
||||||
return input_requests
|
return input_requests
|
||||||
|
|
||||||
@@ -235,6 +252,8 @@ def calculate_metrics(
|
|||||||
outputs: List[RequestFuncOutput],
|
outputs: List[RequestFuncOutput],
|
||||||
dur_s: float,
|
dur_s: float,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
selected_percentile_metrics: List[str],
|
||||||
|
selected_percentiles: List[float],
|
||||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||||
actual_output_lens: List[int] = []
|
actual_output_lens: List[int] = []
|
||||||
total_input = 0
|
total_input = 0
|
||||||
@@ -242,6 +261,7 @@ def calculate_metrics(
|
|||||||
itls: List[float] = []
|
itls: List[float] = []
|
||||||
tpots: List[float] = []
|
tpots: List[float] = []
|
||||||
ttfts: List[float] = []
|
ttfts: List[float] = []
|
||||||
|
e2els: List[float] = []
|
||||||
for i in range(len(outputs)):
|
for i in range(len(outputs)):
|
||||||
if outputs[i].success:
|
if outputs[i].success:
|
||||||
# We use the tokenizer to count the number of output tokens for all
|
# We use the tokenizer to count the number of output tokens for all
|
||||||
@@ -258,6 +278,7 @@ def calculate_metrics(
|
|||||||
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
||||||
itls += outputs[i].itl
|
itls += outputs[i].itl
|
||||||
ttfts.append(outputs[i].ttft)
|
ttfts.append(outputs[i].ttft)
|
||||||
|
e2els.append(outputs[i].latency)
|
||||||
completed += 1
|
completed += 1
|
||||||
else:
|
else:
|
||||||
actual_output_lens.append(0)
|
actual_output_lens.append(0)
|
||||||
@@ -272,21 +293,29 @@ def calculate_metrics(
|
|||||||
total_input=total_input,
|
total_input=total_input,
|
||||||
total_output=sum(actual_output_lens),
|
total_output=sum(actual_output_lens),
|
||||||
request_throughput=completed / dur_s,
|
request_throughput=completed / dur_s,
|
||||||
input_throughput=total_input / dur_s,
|
|
||||||
output_throughput=sum(actual_output_lens) / dur_s,
|
output_throughput=sum(actual_output_lens) / dur_s,
|
||||||
|
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
||||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||||
1000, # ttfts is empty if streaming is not supported by backend
|
1000, # ttfts is empty if streaming is not supported by backend
|
||||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
|
||||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||||
|
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
|
||||||
|
for p in selected_percentiles],
|
||||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
|
||||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||||
|
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
|
||||||
|
for p in selected_percentiles],
|
||||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||||
median_itl_ms=np.median(itls or 0) * 1000,
|
|
||||||
std_itl_ms=np.std(itls or 0) * 1000,
|
std_itl_ms=np.std(itls or 0) * 1000,
|
||||||
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
median_itl_ms=np.median(itls or 0) * 1000,
|
||||||
|
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
|
||||||
|
for p in selected_percentiles],
|
||||||
|
mean_e2el_ms=np.mean(e2els or 0) * 1000,
|
||||||
|
std_e2el_ms=np.std(e2els or 0) * 1000,
|
||||||
|
median_e2el_ms=np.median(e2els or 0) * 1000,
|
||||||
|
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||||
|
for p in selected_percentiles],
|
||||||
)
|
)
|
||||||
|
|
||||||
return metrics, actual_output_lens
|
return metrics, actual_output_lens
|
||||||
@@ -295,13 +324,18 @@ def calculate_metrics(
|
|||||||
async def benchmark(
|
async def benchmark(
|
||||||
backend: str,
|
backend: str,
|
||||||
api_url: str,
|
api_url: str,
|
||||||
|
base_url: str,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
input_requests: List[Tuple[str, int, int]],
|
input_requests: List[Tuple[str, int, int]],
|
||||||
|
logprobs: Optional[int],
|
||||||
best_of: int,
|
best_of: int,
|
||||||
use_beam_search: bool,
|
use_beam_search: bool,
|
||||||
request_rate: float,
|
request_rate: float,
|
||||||
disable_tqdm: bool,
|
disable_tqdm: bool,
|
||||||
|
profile: bool,
|
||||||
|
selected_percentile_metrics: List[str],
|
||||||
|
selected_percentiles: List[float],
|
||||||
):
|
):
|
||||||
if backend in ASYNC_REQUEST_FUNCS:
|
if backend in ASYNC_REQUEST_FUNCS:
|
||||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||||
@@ -316,6 +350,7 @@ async def benchmark(
|
|||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
prompt_len=test_prompt_len,
|
prompt_len=test_prompt_len,
|
||||||
output_len=test_output_len,
|
output_len=test_output_len,
|
||||||
|
logprobs=logprobs,
|
||||||
best_of=best_of,
|
best_of=best_of,
|
||||||
use_beam_search=use_beam_search,
|
use_beam_search=use_beam_search,
|
||||||
)
|
)
|
||||||
@@ -326,6 +361,23 @@ async def benchmark(
|
|||||||
f"are correctly specified. Error: {test_output.error}")
|
f"are correctly specified. Error: {test_output.error}")
|
||||||
else:
|
else:
|
||||||
print("Initial test run completed. Starting main benchmark run...")
|
print("Initial test run completed. Starting main benchmark run...")
|
||||||
|
|
||||||
|
if profile:
|
||||||
|
print("Starting profiler...")
|
||||||
|
profile_input = RequestFuncInput(
|
||||||
|
model=model_id,
|
||||||
|
prompt=test_prompt,
|
||||||
|
api_url=base_url + "/start_profile",
|
||||||
|
prompt_len=test_prompt_len,
|
||||||
|
output_len=test_output_len,
|
||||||
|
logprobs=logprobs,
|
||||||
|
best_of=best_of,
|
||||||
|
use_beam_search=use_beam_search,
|
||||||
|
)
|
||||||
|
profile_output = await request_func(request_func_input=profile_input)
|
||||||
|
if profile_output.success:
|
||||||
|
print("Profiler started")
|
||||||
|
|
||||||
print(f"Traffic request rate: {request_rate}")
|
print(f"Traffic request rate: {request_rate}")
|
||||||
|
|
||||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||||
@@ -340,6 +392,7 @@ async def benchmark(
|
|||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
output_len=output_len,
|
output_len=output_len,
|
||||||
|
logprobs=logprobs,
|
||||||
best_of=best_of,
|
best_of=best_of,
|
||||||
use_beam_search=use_beam_search,
|
use_beam_search=use_beam_search,
|
||||||
)
|
)
|
||||||
@@ -349,6 +402,22 @@ async def benchmark(
|
|||||||
pbar=pbar)))
|
pbar=pbar)))
|
||||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
if profile:
|
||||||
|
print("Stopping profiler...")
|
||||||
|
profile_input = RequestFuncInput(
|
||||||
|
model=model_id,
|
||||||
|
prompt=test_prompt,
|
||||||
|
api_url=base_url + "/stop_profile",
|
||||||
|
prompt_len=test_prompt_len,
|
||||||
|
output_len=test_output_len,
|
||||||
|
logprobs=logprobs,
|
||||||
|
best_of=best_of,
|
||||||
|
use_beam_search=use_beam_search,
|
||||||
|
)
|
||||||
|
profile_output = await request_func(request_func_input=profile_input)
|
||||||
|
if profile_output.success:
|
||||||
|
print("Profiler stopped")
|
||||||
|
|
||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
@@ -359,6 +428,8 @@ async def benchmark(
|
|||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
dur_s=benchmark_duration,
|
dur_s=benchmark_duration,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
selected_percentile_metrics=selected_percentile_metrics,
|
||||||
|
selected_percentiles=selected_percentiles,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||||
@@ -370,27 +441,10 @@ async def benchmark(
|
|||||||
metrics.total_output))
|
metrics.total_output))
|
||||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||||
metrics.request_throughput))
|
metrics.request_throughput))
|
||||||
print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
|
|
||||||
metrics.input_throughput))
|
|
||||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||||
metrics.output_throughput))
|
metrics.output_throughput))
|
||||||
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
|
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
metrics.total_token_throughput))
|
||||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
|
|
||||||
metrics.median_ttft_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
|
||||||
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
|
|
||||||
n=50,
|
|
||||||
c='-'))
|
|
||||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
|
||||||
metrics.median_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
|
||||||
print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
|
|
||||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"duration": benchmark_duration,
|
"duration": benchmark_duration,
|
||||||
@@ -398,20 +452,8 @@ async def benchmark(
|
|||||||
"total_input_tokens": metrics.total_input,
|
"total_input_tokens": metrics.total_input,
|
||||||
"total_output_tokens": metrics.total_output,
|
"total_output_tokens": metrics.total_output,
|
||||||
"request_throughput": metrics.request_throughput,
|
"request_throughput": metrics.request_throughput,
|
||||||
"input_throughput": metrics.input_throughput,
|
|
||||||
"output_throughput": metrics.output_throughput,
|
"output_throughput": metrics.output_throughput,
|
||||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
"total_token_throughput": metrics.total_token_throughput,
|
||||||
"median_ttft_ms": metrics.median_ttft_ms,
|
|
||||||
"std_ttft_ms": metrics.std_ttft_ms,
|
|
||||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
|
||||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
|
||||||
"median_tpot_ms": metrics.median_tpot_ms,
|
|
||||||
"std_tpot_ms": metrics.std_tpot_ms,
|
|
||||||
"p99_tpot_ms": metrics.p99_tpot_ms,
|
|
||||||
"mean_itl_ms": metrics.mean_itl_ms,
|
|
||||||
"median_itl_ms": metrics.median_itl_ms,
|
|
||||||
"std_itl_ms": metrics.std_itl_ms,
|
|
||||||
"p99_itl_ms": metrics.p99_itl_ms,
|
|
||||||
"input_lens": [output.prompt_len for output in outputs],
|
"input_lens": [output.prompt_len for output in outputs],
|
||||||
"output_lens": actual_output_lens,
|
"output_lens": actual_output_lens,
|
||||||
"ttfts": [output.ttft for output in outputs],
|
"ttfts": [output.ttft for output in outputs],
|
||||||
@@ -419,6 +461,47 @@ async def benchmark(
|
|||||||
"generated_texts": [output.generated_text for output in outputs],
|
"generated_texts": [output.generated_text for output in outputs],
|
||||||
"errors": [output.error for output in outputs],
|
"errors": [output.error for output in outputs],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def process_one_metric(
|
||||||
|
# E.g., "ttft"
|
||||||
|
metric_attribute_name: str,
|
||||||
|
# E.g., "TTFT"
|
||||||
|
metric_name: str,
|
||||||
|
# E.g., "Time to First Token"
|
||||||
|
metric_header: str,
|
||||||
|
):
|
||||||
|
# This function prints and adds statistics of the specified
|
||||||
|
# metric.
|
||||||
|
if metric_attribute_name not in selected_percentile_metrics:
|
||||||
|
return
|
||||||
|
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||||
|
print("{:<40} {:<10.2f}".format(
|
||||||
|
f"Mean {metric_name} (ms):",
|
||||||
|
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||||
|
print("{:<40} {:<10.2f}".format(
|
||||||
|
f"Median {metric_name} (ms):",
|
||||||
|
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||||
|
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||||
|
metrics, f"mean_{metric_attribute_name}_ms")
|
||||||
|
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||||
|
metrics, f"median_{metric_attribute_name}_ms")
|
||||||
|
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||||
|
metrics, f"std_{metric_attribute_name}_ms")
|
||||||
|
for p, value in getattr(metrics,
|
||||||
|
f"percentiles_{metric_attribute_name}_ms"):
|
||||||
|
p_word = str(int(p)) if int(p) == p else str(p)
|
||||||
|
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||||
|
value))
|
||||||
|
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||||
|
|
||||||
|
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||||
|
process_one_metric("tpot", "TPOT",
|
||||||
|
"Time per Output Token (excl. 1st token)")
|
||||||
|
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||||
|
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||||
|
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -433,8 +516,10 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
if args.base_url is not None:
|
if args.base_url is not None:
|
||||||
api_url = f"{args.base_url}{args.endpoint}"
|
api_url = f"{args.base_url}{args.endpoint}"
|
||||||
|
base_url = f"{args.base_url}"
|
||||||
else:
|
else:
|
||||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||||
|
base_url = f"http://{args.host}:{args.port}"
|
||||||
|
|
||||||
tokenizer = get_tokenizer(tokenizer_id,
|
tokenizer = get_tokenizer(tokenizer_id,
|
||||||
trust_remote_code=args.trust_remote_code)
|
trust_remote_code=args.trust_remote_code)
|
||||||
@@ -492,6 +577,7 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
elif args.dataset_name == "random":
|
elif args.dataset_name == "random":
|
||||||
input_requests = sample_random_requests(
|
input_requests = sample_random_requests(
|
||||||
|
prefix_len=args.random_prefix_len,
|
||||||
input_len=args.random_input_len,
|
input_len=args.random_input_len,
|
||||||
output_len=args.random_output_len,
|
output_len=args.random_output_len,
|
||||||
num_prompts=args.num_prompts,
|
num_prompts=args.num_prompts,
|
||||||
@@ -506,13 +592,20 @@ def main(args: argparse.Namespace):
|
|||||||
benchmark(
|
benchmark(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
|
base_url=base_url,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
input_requests=input_requests,
|
input_requests=input_requests,
|
||||||
|
logprobs=args.logprobs,
|
||||||
best_of=args.best_of,
|
best_of=args.best_of,
|
||||||
use_beam_search=args.use_beam_search,
|
use_beam_search=args.use_beam_search,
|
||||||
request_rate=args.request_rate,
|
request_rate=args.request_rate,
|
||||||
disable_tqdm=args.disable_tqdm,
|
disable_tqdm=args.disable_tqdm,
|
||||||
|
profile=args.profile,
|
||||||
|
selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||||
|
selected_percentiles=[
|
||||||
|
float(p) for p in args.metric_percentiles.split(",")
|
||||||
|
],
|
||||||
))
|
))
|
||||||
|
|
||||||
# Save config and results to json
|
# Save config and results to json
|
||||||
@@ -645,6 +738,16 @@ if __name__ == "__main__":
|
|||||||
help=
|
help=
|
||||||
"Number of output tokens per request, used only for sonnet dataset.",
|
"Number of output tokens per request, used only for sonnet dataset.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--logprobs",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help=("Number of logprobs-per-token to compute & return as part of "
|
||||||
|
"the request. If unspecified, then either (1) if beam search "
|
||||||
|
"is disabled, no logprobs are computed & a single dummy "
|
||||||
|
"logprob is returned for each token; or (2) if beam search "
|
||||||
|
"is enabled 1 logprob per token is computed"),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sonnet-prefix-len",
|
"--sonnet-prefix-len",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -673,6 +776,14 @@ if __name__ == "__main__":
|
|||||||
help="Range of sampled ratio of input/output length, "
|
help="Range of sampled ratio of input/output length, "
|
||||||
"used only for random sampling.",
|
"used only for random sampling.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--random-prefix-len",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Number of fixed prefix tokens before random "
|
||||||
|
" context. The length range of context in a random "
|
||||||
|
" request is [random-prefix-len, "
|
||||||
|
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--request-rate",
|
"--request-rate",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -693,6 +804,12 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Specify to disable tqdm progress bar.",
|
help="Specify to disable tqdm progress bar.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Torch Profiler. The endpoint must be launched with "
|
||||||
|
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-result",
|
"--save-result",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -722,6 +839,23 @@ if __name__ == "__main__":
|
|||||||
"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
||||||
" format.",
|
" format.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--percentile-metrics",
|
||||||
|
type=str,
|
||||||
|
default="ttft,tpot,itl",
|
||||||
|
help="Comma-separated list of selected metrics to report percentiles. "
|
||||||
|
"This argument specifies the metrics to report percentiles. "
|
||||||
|
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||||
|
"Default value is \"ttft,tpot,itl\".")
|
||||||
|
parser.add_argument(
|
||||||
|
"--metric-percentiles",
|
||||||
|
type=str,
|
||||||
|
default="99",
|
||||||
|
help="Comma-separated list of percentiles for selected metrics. "
|
||||||
|
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||||
|
"Default value is \"99\". "
|
||||||
|
"Use \"--percentile-metrics\" to select metrics.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -6,13 +6,16 @@ import time
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import uvloop
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||||
PreTrainedTokenizerBase)
|
PreTrainedTokenizerBase)
|
||||||
|
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
|
||||||
|
from vllm.entrypoints.openai.api_server import (
|
||||||
|
build_async_engine_client_from_engine_args)
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||||
|
|
||||||
|
|
||||||
def sample_requests(
|
def sample_requests(
|
||||||
@@ -82,8 +85,11 @@ def run_vllm(
|
|||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
distributed_executor_backend: Optional[str],
|
distributed_executor_backend: Optional[str],
|
||||||
gpu_memory_utilization: float = 0.9,
|
gpu_memory_utilization: float = 0.9,
|
||||||
|
num_scheduler_steps: int = 1,
|
||||||
|
use_v2_block_manager: bool = False,
|
||||||
download_dir: Optional[str] = None,
|
download_dir: Optional[str] = None,
|
||||||
load_format: str = EngineArgs.load_format,
|
load_format: str = EngineArgs.load_format,
|
||||||
|
disable_async_output_proc: bool = False,
|
||||||
) -> float:
|
) -> float:
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
@@ -106,6 +112,9 @@ def run_vllm(
|
|||||||
max_num_batched_tokens=max_num_batched_tokens,
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
distributed_executor_backend=distributed_executor_backend,
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
load_format=load_format,
|
load_format=load_format,
|
||||||
|
num_scheduler_steps=num_scheduler_steps,
|
||||||
|
use_v2_block_manager=use_v2_block_manager,
|
||||||
|
disable_async_output_proc=disable_async_output_proc,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
@@ -129,6 +138,93 @@ def run_vllm(
|
|||||||
return end - start
|
return end - start
|
||||||
|
|
||||||
|
|
||||||
|
async def run_vllm_async(
|
||||||
|
requests: List[Tuple[str, int, int]],
|
||||||
|
model: str,
|
||||||
|
tokenizer: str,
|
||||||
|
quantization: Optional[str],
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
seed: int,
|
||||||
|
n: int,
|
||||||
|
use_beam_search: bool,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
dtype: str,
|
||||||
|
max_model_len: Optional[int],
|
||||||
|
enforce_eager: bool,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
quantization_param_path: Optional[str],
|
||||||
|
device: str,
|
||||||
|
enable_prefix_caching: bool,
|
||||||
|
enable_chunked_prefill: bool,
|
||||||
|
max_num_batched_tokens: int,
|
||||||
|
distributed_executor_backend: Optional[str],
|
||||||
|
gpu_memory_utilization: float = 0.9,
|
||||||
|
num_scheduler_steps: int = 1,
|
||||||
|
use_v2_block_manager: bool = False,
|
||||||
|
download_dir: Optional[str] = None,
|
||||||
|
load_format: str = EngineArgs.load_format,
|
||||||
|
disable_async_output_proc: bool = False,
|
||||||
|
disable_frontend_multiprocessing: bool = False,
|
||||||
|
) -> float:
|
||||||
|
from vllm import SamplingParams
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
quantization=quantization,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
seed=seed,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
dtype=dtype,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
quantization_param_path=quantization_param_path,
|
||||||
|
device=device,
|
||||||
|
enable_prefix_caching=enable_prefix_caching,
|
||||||
|
download_dir=download_dir,
|
||||||
|
enable_chunked_prefill=enable_chunked_prefill,
|
||||||
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
load_format=load_format,
|
||||||
|
num_scheduler_steps=num_scheduler_steps,
|
||||||
|
use_v2_block_manager=use_v2_block_manager,
|
||||||
|
disable_async_output_proc=disable_async_output_proc,
|
||||||
|
worker_use_ray=False,
|
||||||
|
engine_use_ray=False,
|
||||||
|
disable_log_requests=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with build_async_engine_client_from_engine_args(
|
||||||
|
engine_args, disable_frontend_multiprocessing) as llm:
|
||||||
|
|
||||||
|
# Add the requests to the engine.
|
||||||
|
prompts: List[str] = []
|
||||||
|
sampling_params: List[SamplingParams] = []
|
||||||
|
for prompt, _, output_len in requests:
|
||||||
|
prompts.append(prompt)
|
||||||
|
sampling_params.append(
|
||||||
|
SamplingParams(
|
||||||
|
n=n,
|
||||||
|
temperature=0.0 if use_beam_search else 1.0,
|
||||||
|
top_p=1.0,
|
||||||
|
use_beam_search=use_beam_search,
|
||||||
|
ignore_eos=True,
|
||||||
|
max_tokens=output_len,
|
||||||
|
))
|
||||||
|
|
||||||
|
generators = []
|
||||||
|
start = time.perf_counter()
|
||||||
|
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||||
|
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||||
|
generators.append(generator)
|
||||||
|
all_gens = merge_async_iterators(*generators)
|
||||||
|
async for i, res in all_gens:
|
||||||
|
pass
|
||||||
|
end = time.perf_counter()
|
||||||
|
return end - start
|
||||||
|
|
||||||
|
|
||||||
def run_hf(
|
def run_hf(
|
||||||
requests: List[Tuple[str, int, int]],
|
requests: List[Tuple[str, int, int]],
|
||||||
model: str,
|
model: str,
|
||||||
@@ -224,7 +320,7 @@ def main(args: argparse.Namespace):
|
|||||||
args.output_len)
|
args.output_len)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
elapsed_time = run_vllm(
|
run_args = [
|
||||||
requests, args.model, args.tokenizer, args.quantization,
|
requests, args.model, args.tokenizer, args.quantization,
|
||||||
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
||||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
args.trust_remote_code, args.dtype, args.max_model_len,
|
||||||
@@ -232,7 +328,16 @@ def main(args: argparse.Namespace):
|
|||||||
args.quantization_param_path, args.device,
|
args.quantization_param_path, args.device,
|
||||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||||
args.max_num_batched_tokens, args.distributed_executor_backend,
|
args.max_num_batched_tokens, args.distributed_executor_backend,
|
||||||
args.gpu_memory_utilization, args.download_dir, args.load_format)
|
args.gpu_memory_utilization, args.num_scheduler_steps,
|
||||||
|
args.use_v2_block_manager, args.download_dir, args.load_format,
|
||||||
|
args.disable_async_output_proc
|
||||||
|
]
|
||||||
|
|
||||||
|
if args.async_engine:
|
||||||
|
run_args.append(args.disable_frontend_multiprocessing)
|
||||||
|
elapsed_time = uvloop.run(run_vllm_async(*run_args))
|
||||||
|
else:
|
||||||
|
elapsed_time = run_vllm(*run_args)
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
@@ -346,17 +451,23 @@ if __name__ == "__main__":
|
|||||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
||||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
||||||
'instead supported for common inference criteria.')
|
'instead supported for common inference criteria.')
|
||||||
|
parser.add_argument("--device",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
choices=DEVICE_OPTIONS,
|
||||||
|
help='device type for vLLM execution')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device",
|
"--num-scheduler-steps",
|
||||||
type=str,
|
type=int,
|
||||||
default="auto",
|
default=1,
|
||||||
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
|
help="Maximum number of forward steps per scheduler call.")
|
||||||
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
|
parser.add_argument("--use-v2-block-manager",
|
||||||
'CPU.')
|
action='store_true',
|
||||||
|
help="Enable block manager v2.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-prefix-caching",
|
"--enable-prefix-caching",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="enable automatic prefix caching for vLLM backend.")
|
help="Enable automatic prefix caching for vLLM backend.")
|
||||||
parser.add_argument("--enable-chunked-prefill",
|
parser.add_argument("--enable-chunked-prefill",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="enable chunked prefill for vLLM backend.")
|
help="enable chunked prefill for vLLM backend.")
|
||||||
@@ -405,6 +516,19 @@ if __name__ == "__main__":
|
|||||||
'section for more information.\n'
|
'section for more information.\n'
|
||||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
'* "bitsandbytes" will load the weights using bitsandbytes '
|
||||||
'quantization.\n')
|
'quantization.\n')
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-async-output-proc",
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help="Disable async output processor for vLLM backend.")
|
||||||
|
parser.add_argument("--async-engine",
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help="Use vLLM async engine rather than LLM class.")
|
||||||
|
parser.add_argument("--disable-frontend-multiprocessing",
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help="Disable decoupled async engine frontend.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.tokenizer is None:
|
if args.tokenizer is None:
|
||||||
args.tokenizer = args.model
|
args.tokenizer = args.model
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
a = torch.randn((m, k), device='cuda') * 5
|
a = torch.randn((m, k), device='cuda') * 5
|
||||||
b = torch.randn((n, k), device='cuda').t() * 5
|
b = torch.randn((n, k), device='cuda').t() * 5
|
||||||
|
|
||||||
@@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
|||||||
raise ValueError("unsupported dtype")
|
raise ValueError("unsupported dtype")
|
||||||
|
|
||||||
|
|
||||||
# impl
|
|
||||||
|
|
||||||
|
|
||||||
def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
|
||||||
scale_b: torch.Tensor,
|
|
||||||
out_dtype: torch.dtype) -> torch.Tensor:
|
|
||||||
return torch.mm(a, b)
|
|
||||||
|
|
||||||
|
|
||||||
def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
|
||||||
scale_b: torch.Tensor,
|
|
||||||
out_dtype: torch.dtype) -> torch.Tensor:
|
|
||||||
return torch._scaled_mm(a,
|
|
||||||
b,
|
|
||||||
scale_a=scale_a,
|
|
||||||
scale_b=scale_b,
|
|
||||||
out_dtype=out_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
|
|
||||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
|
||||||
out_dtype: torch.dtype) -> torch.Tensor:
|
|
||||||
return torch._scaled_mm(a,
|
|
||||||
b,
|
|
||||||
scale_a=scale_a,
|
|
||||||
scale_b=scale_b,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
use_fast_accum=True)
|
|
||||||
|
|
||||||
|
|
||||||
def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
|
||||||
scale_b: torch.Tensor,
|
|
||||||
out_dtype: torch.dtype) -> torch.Tensor:
|
|
||||||
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# bench
|
# bench
|
||||||
def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||||
scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
|
**kwargs) -> TMeasurement:
|
||||||
sub_label: str, fn: Callable, description: str) -> TMeasurement:
|
|
||||||
|
|
||||||
min_run_time = 1
|
min_run_time = 1
|
||||||
|
|
||||||
globals = {
|
globals = {
|
||||||
"a": a,
|
"args": args,
|
||||||
"b": b,
|
"kwargs": kwargs,
|
||||||
"scale_a": scale_a,
|
|
||||||
"scale_b": scale_b,
|
|
||||||
"out_dtype": out_dtype,
|
|
||||||
"fn": fn,
|
"fn": fn,
|
||||||
}
|
}
|
||||||
return TBenchmark.Timer(
|
return TBenchmark.Timer(
|
||||||
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
|
stmt="fn(*args, **kwargs)",
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
@@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
|||||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||||
|
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
|
||||||
|
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
|
||||||
|
|
||||||
timers = []
|
timers = []
|
||||||
# pytorch impl - bfloat16
|
# pytorch impl - bfloat16
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
torch.mm, a.to(dtype=torch.bfloat16),
|
||||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
b.to(dtype=torch.bfloat16)))
|
||||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
|
||||||
|
|
||||||
# pytorch impl - float16
|
# pytorch impl - float16
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a.to(dtype=torch.float16, device="cuda"),
|
bench_fn(label, sub_label,
|
||||||
b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
|
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||||
torch.float16, label, sub_label, pytorch_mm_impl,
|
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||||
"pytorch_fp16_fp16_fp16_matmul-no-scales"))
|
|
||||||
|
|
||||||
# cutlass impl
|
# cutlass impl
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||||
cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||||
|
torch.bfloat16))
|
||||||
|
|
||||||
|
# cutlass with bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||||
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||||
|
bias))
|
||||||
|
|
||||||
|
# cutlass with azp per-tensor
|
||||||
|
timers.append(
|
||||||
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
|
||||||
|
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||||
|
torch.bfloat16, azp_adj))
|
||||||
|
|
||||||
|
# cutlass with azp per-tensor + bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
|
||||||
|
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||||
|
torch.bfloat16, azp_adj, None, bias))
|
||||||
|
|
||||||
|
# cutlass with azp per-token
|
||||||
|
timers.append(
|
||||||
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
|
||||||
|
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||||
|
torch.bfloat16, azp_adj, azp))
|
||||||
|
|
||||||
|
# cutlass with azp per-token + bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
|
||||||
|
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||||
|
torch.bfloat16, azp_adj, azp, bias))
|
||||||
|
|
||||||
return timers
|
return timers
|
||||||
|
|
||||||
@@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
|||||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
timers = []
|
timers = []
|
||||||
|
|
||||||
# pytorch impl w. bf16
|
# pytorch impl w. bf16
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
|
||||||
|
|
||||||
# pytorch impl: bf16 output, without fp8 fast accum
|
# pytorch impl: bf16 output, without fp8 fast accum
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
bench_fn(label,
|
||||||
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
|
sub_label,
|
||||||
|
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||||
|
torch._scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=torch.bfloat16))
|
||||||
|
|
||||||
# pytorch impl: bf16 output, with fp8 fast accum
|
# pytorch impl: bf16 output, with fp8 fast accum
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
bench_fn(label,
|
||||||
pytorch_fp8_impl_fast_accum,
|
sub_label,
|
||||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
|
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||||
|
torch._scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
use_fast_accum=True))
|
||||||
|
|
||||||
# pytorch impl: fp16 output, without fp8 fast accum
|
# pytorch impl: fp16 output, without fp8 fast accum
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
bench_fn(label,
|
||||||
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
|
sub_label,
|
||||||
|
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||||
|
torch._scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=torch.float16))
|
||||||
|
|
||||||
# pytorch impl: fp16 output, with fp8 fast accum
|
# pytorch impl: fp16 output, with fp8 fast accum
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
bench_fn(label,
|
||||||
pytorch_fp8_impl_fast_accum,
|
sub_label,
|
||||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
|
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||||
|
torch._scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=torch.float16,
|
||||||
|
use_fast_accum=True))
|
||||||
|
|
||||||
# cutlass impl: bf16 output
|
# cutlass impl: bf16 output
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||||
cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||||
|
torch.bfloat16))
|
||||||
# cutlass impl: fp16 output
|
# cutlass impl: fp16 output
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
|
||||||
cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
|
||||||
|
|
||||||
|
# cutlass impl: bf16 output, with bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
|
||||||
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||||
|
bias))
|
||||||
|
|
||||||
|
# cutlass impl: fp16 output, with bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
|
||||||
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
|
||||||
|
bias.to(dtype=torch.float16)))
|
||||||
|
|
||||||
return timers
|
return timers
|
||||||
|
|
||||||
|
|
||||||
@@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]):
|
|||||||
|
|
||||||
def run(dtype: torch.dtype,
|
def run(dtype: torch.dtype,
|
||||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for m, k, n in MKNs:
|
for m, k, n in MKNs:
|
||||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||||
@@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement],
|
|||||||
MKNs: Iterable[Tuple[int, int, int]],
|
MKNs: Iterable[Tuple[int, int, int]],
|
||||||
base_description: str,
|
base_description: str,
|
||||||
timestamp=None):
|
timestamp=None):
|
||||||
|
|
||||||
print(f"== All Results {base_description} ====")
|
print(f"== All Results {base_description} ====")
|
||||||
print_timers(data)
|
print_timers(data)
|
||||||
|
|
||||||
@@ -251,7 +281,6 @@ def run_range_bench(args):
|
|||||||
|
|
||||||
|
|
||||||
def run_model_bench(args):
|
def run_model_bench(args):
|
||||||
|
|
||||||
print("Benchmarking models:")
|
print("Benchmarking models:")
|
||||||
for i, model in enumerate(args.models):
|
for i, model in enumerate(args.models):
|
||||||
print(f"[{i}] {model}")
|
print(f"[{i}] {model}")
|
||||||
|
|||||||
89
benchmarks/kernels/benchmark_layernorm.py
Normal file
89
benchmarks/kernels/benchmark_layernorm.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main(num_tokens: int,
|
||||||
|
hidden_size: int,
|
||||||
|
add_residual: bool,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int = 0,
|
||||||
|
do_profile: bool = False,
|
||||||
|
num_warmup_iters: int = 5,
|
||||||
|
num_iters: int = 100) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
layer = RMSNorm(hidden_size).to(dtype=dtype)
|
||||||
|
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||||
|
scale = 1 / (2 * hidden_size)
|
||||||
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
x *= scale
|
||||||
|
residual = torch.randn_like(x) * scale if add_residual else None
|
||||||
|
|
||||||
|
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if profile:
|
||||||
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
for _ in range(num_iters):
|
||||||
|
layer(x, residual)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
if profile:
|
||||||
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
|
return (end_time - start_time) / num_iters
|
||||||
|
|
||||||
|
# Warmup.
|
||||||
|
print("Warming up...")
|
||||||
|
run_benchmark = run_cuda_benchmark
|
||||||
|
run_benchmark(num_iters=num_warmup_iters, profile=False)
|
||||||
|
|
||||||
|
# Benchmark.
|
||||||
|
if do_profile:
|
||||||
|
latency = run_benchmark(num_iters=1, profile=True)
|
||||||
|
else:
|
||||||
|
latency = run_benchmark(num_iters=num_iters, profile=False)
|
||||||
|
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the layernorm kernel.")
|
||||||
|
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||||
|
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||||
|
parser.add_argument("--add-residual", action="store_true")
|
||||||
|
parser.add_argument("--dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["half", "bfloat16", "float"],
|
||||||
|
default="half")
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--profile", action="store_true")
|
||||||
|
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||||
|
parser.add_argument("--num-iters",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Number of benchmark iterations. "
|
||||||
|
"If --profile is set, this number is ignored")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
main(num_tokens=args.num_tokens,
|
||||||
|
hidden_size=args.hidden_size,
|
||||||
|
add_residual=args.add_residual,
|
||||||
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
|
seed=args.seed,
|
||||||
|
do_profile=args.profile,
|
||||||
|
num_warmup_iters=args.num_warmup_iters,
|
||||||
|
num_iters=args.num_iters)
|
||||||
372
benchmarks/kernels/benchmark_machete.py
Normal file
372
benchmarks/kernels/benchmark_machete.py
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
import pickle as pkl
|
||||||
|
import time
|
||||||
|
from typing import Callable, Iterable, List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
|
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
|
MarlinWorkspace)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
gptq_pack, pack_rows, quantize_weights)
|
||||||
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
|
||||||
|
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
|
||||||
|
def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
|
||||||
|
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||||
|
w_q = w_q.t().contiguous().t() # make col major
|
||||||
|
return ops.machete_prepack_B(w_q, wtype)
|
||||||
|
|
||||||
|
|
||||||
|
def make_bench_tensors(
|
||||||
|
atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
|
||||||
|
k: int
|
||||||
|
) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
|
||||||
|
torch.tensor]]]:
|
||||||
|
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||||
|
|
||||||
|
# we want to make sure that weights don't fit into L2 cache between runs so
|
||||||
|
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
|
||||||
|
# so we target total weight size > 2*50mb
|
||||||
|
num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
|
||||||
|
|
||||||
|
a = torch.randn((m, k), device="cuda", dtype=atype) * 5
|
||||||
|
weights = [
|
||||||
|
torch.randn((k, n), device="cuda", dtype=atype)
|
||||||
|
for _ in range(num_weights)
|
||||||
|
]
|
||||||
|
quanitized_weights = [
|
||||||
|
quantize_weights(w, wtype, group_size) for w in weights
|
||||||
|
]
|
||||||
|
|
||||||
|
return a, quanitized_weights
|
||||||
|
|
||||||
|
|
||||||
|
# impl
|
||||||
|
|
||||||
|
|
||||||
|
# bench
|
||||||
|
def bench_fn(label: str, sub_label: str, description: str,
|
||||||
|
fn: Callable) -> TMeasurement:
|
||||||
|
|
||||||
|
min_run_time = 1
|
||||||
|
return TBenchmark.Timer(
|
||||||
|
stmt="fn()",
|
||||||
|
globals={
|
||||||
|
"fn": fn
|
||||||
|
},
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description=description,
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
|
|
||||||
|
def loop_over_weights(
|
||||||
|
a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
|
||||||
|
torch.tensor, torch.tensor]],
|
||||||
|
fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
|
||||||
|
None]):
|
||||||
|
for w_ref, w_q, w_s, _ in weights:
|
||||||
|
fn(a, w_ref, w_q, w_s)
|
||||||
|
|
||||||
|
|
||||||
|
def bench(atype: torch.dtype,
|
||||||
|
wtype: ScalarType,
|
||||||
|
group_size: int,
|
||||||
|
m: int,
|
||||||
|
k: int,
|
||||||
|
n: int,
|
||||||
|
label: str,
|
||||||
|
sub_label: str,
|
||||||
|
benchmark_marlinv1: bool = True,
|
||||||
|
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
|
||||||
|
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
|
||||||
|
sub_label += f", L={len(weights)}"
|
||||||
|
|
||||||
|
weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
|
||||||
|
for w_ref, w_q, w_s, w_zp in weights]
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
# pytorch impl
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label, sub_label, "torch.matmul", lambda: loop_over_weights(
|
||||||
|
a,
|
||||||
|
weights,
|
||||||
|
lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
|
||||||
|
)))
|
||||||
|
|
||||||
|
if benchmark_marlinv1:
|
||||||
|
w_ref = weights[0][0]
|
||||||
|
|
||||||
|
w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||||
|
sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||||
|
g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||||
|
|
||||||
|
def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
|
||||||
|
w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
|
||||||
|
return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
|
||||||
|
wtype.size_bits)
|
||||||
|
|
||||||
|
def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
|
||||||
|
return marlin_permute_scales(w_s, *w_ref.shape, group_size)
|
||||||
|
|
||||||
|
weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
|
||||||
|
marlinv1_permute_scales(w_s), w_zp)
|
||||||
|
for w_ref, w_q, w_s, w_zp in weights]
|
||||||
|
|
||||||
|
workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
|
GPTQ_MARLIN_MAX_PARALLEL)
|
||||||
|
|
||||||
|
# marlinv1
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label, sub_label, "marlin_orig", lambda: loop_over_weights(
|
||||||
|
a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
|
||||||
|
gptq_marlin_gemm(a,
|
||||||
|
w_q,
|
||||||
|
w_s,
|
||||||
|
w_zp_empty,
|
||||||
|
g_idx,
|
||||||
|
sort_indices,
|
||||||
|
workspace.scratch,
|
||||||
|
wtype,
|
||||||
|
size_m=a.shape[0],
|
||||||
|
size_n=w_ref.shape[1],
|
||||||
|
size_k=w_ref.shape[0],
|
||||||
|
is_k_full=True))))
|
||||||
|
|
||||||
|
# machete
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label, sub_label, "machete_heuristic", lambda: loop_over_weights(
|
||||||
|
a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
|
||||||
|
a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))
|
||||||
|
|
||||||
|
if sweep_schedules:
|
||||||
|
print("Finding best schedule for machete")
|
||||||
|
best = None
|
||||||
|
best_schedule = None
|
||||||
|
schedules = ops.machete_supported_schedules(wtype)
|
||||||
|
for schedule in reversed(schedules):
|
||||||
|
|
||||||
|
def run(a, _, w_q, w_s, schedule=schedule):
|
||||||
|
ops.machete_gemm(a,
|
||||||
|
w_q,
|
||||||
|
wtype,
|
||||||
|
w_s,
|
||||||
|
b_group_size=group_size,
|
||||||
|
schedule=schedule)
|
||||||
|
|
||||||
|
res = bench_fn(label, sub_label, "machete_best",
|
||||||
|
lambda: loop_over_weights(a, weights_machete, run))
|
||||||
|
|
||||||
|
print(f" {res.median:5.5} ", schedule)
|
||||||
|
if not best or res.median < best.median:
|
||||||
|
best = res
|
||||||
|
best_schedule = schedule
|
||||||
|
print("Best schedule:", best_schedule)
|
||||||
|
timers.append(best)
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
# runner
|
||||||
|
def print_timers(timers: Iterable[TMeasurement]):
|
||||||
|
compare = TBenchmark.Compare(timers)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
def run(dtype: torch.dtype, sweep_schedules: bool,
|
||||||
|
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for m, k, n in MKNs:
|
||||||
|
timers = bench(dtype,
|
||||||
|
scalar_types.uint4b8,
|
||||||
|
128,
|
||||||
|
m,
|
||||||
|
k,
|
||||||
|
n,
|
||||||
|
f"{dtype}-gemm",
|
||||||
|
f"MKN=({m}x{k}x{n})",
|
||||||
|
sweep_schedules=sweep_schedules)
|
||||||
|
print_timers(timers)
|
||||||
|
results.extend(timers)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# output makers
|
||||||
|
def make_output(
|
||||||
|
data: Iterable[TMeasurement],
|
||||||
|
MKNs: Iterable[Tuple[int, int, int]],
|
||||||
|
base_description: str,
|
||||||
|
timestamp=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
print(f"== All Results {base_description} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
# pickle all the results
|
||||||
|
timestamp = int(time.time()) if timestamp is None else timestamp
|
||||||
|
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(data, f)
|
||||||
|
|
||||||
|
|
||||||
|
# argparse runners
|
||||||
|
|
||||||
|
|
||||||
|
def run_square_bench(args):
|
||||||
|
dim_sizes = list(
|
||||||
|
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||||
|
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||||
|
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||||
|
|
||||||
|
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_range_bench(args):
|
||||||
|
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
|
||||||
|
n = len(dim_sizes)
|
||||||
|
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
|
||||||
|
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||||
|
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||||
|
MKNs = list(zip(Ms, Ks, Ns))
|
||||||
|
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||||
|
|
||||||
|
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_model_bench(args):
|
||||||
|
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
|
||||||
|
KNs = []
|
||||||
|
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||||
|
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||||
|
KNs.append(KN)
|
||||||
|
return KNs
|
||||||
|
|
||||||
|
model_bench_data = []
|
||||||
|
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||||
|
for model, tp_size in models_tps:
|
||||||
|
Ms = args.batch_sizes
|
||||||
|
KNs = model_shapes(model, tp_size)
|
||||||
|
MKNs = []
|
||||||
|
for m in Ms:
|
||||||
|
for k, n in KNs:
|
||||||
|
MKNs.append((m, k, n))
|
||||||
|
|
||||||
|
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||||
|
model_bench_data.append(data)
|
||||||
|
|
||||||
|
# Print all results
|
||||||
|
for data, model_tp in zip(model_bench_data, models_tps):
|
||||||
|
model, tp_size = model_tp
|
||||||
|
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
timestamp = int(time.time())
|
||||||
|
|
||||||
|
all_data = []
|
||||||
|
for d in model_bench_data:
|
||||||
|
all_data.extend(d)
|
||||||
|
# pickle all data
|
||||||
|
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(all_data, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
def to_torch_dtype(dt):
|
||||||
|
if dt == "bfloat16":
|
||||||
|
return torch.bfloat16
|
||||||
|
if dt == "float16":
|
||||||
|
return torch.float16
|
||||||
|
raise ValueError("unsupported dtype")
|
||||||
|
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="""
|
||||||
|
Benchmark Machete GEMM.
|
||||||
|
|
||||||
|
To run square GEMMs:
|
||||||
|
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
|
||||||
|
|
||||||
|
To run constant N and K and sweep M:
|
||||||
|
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
|
||||||
|
|
||||||
|
To run dimensions from a model:
|
||||||
|
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||||
|
""", # noqa: E501
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=to_torch_dtype,
|
||||||
|
required=True,
|
||||||
|
help="Available options are ['bfloat16', 'float16']",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sweep-schedules",
|
||||||
|
action="store_true",
|
||||||
|
help="Run a sweep over all supported schedules",
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(dest="cmd", required=True)
|
||||||
|
|
||||||
|
square_parser = subparsers.add_parser("square_bench")
|
||||||
|
square_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
square_parser.set_defaults(func=run_square_bench)
|
||||||
|
|
||||||
|
range_parser = subparsers.add_parser("range_bench")
|
||||||
|
range_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
range_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
range_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
range_parser.add_argument("--m-constant", type=int, default=None)
|
||||||
|
range_parser.add_argument("--n-constant", type=int, default=None)
|
||||||
|
range_parser.add_argument("--k-constant", type=int, default=None)
|
||||||
|
range_parser.set_defaults(func=run_range_bench)
|
||||||
|
|
||||||
|
model_parser = subparsers.add_parser("model_bench")
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES.keys(),
|
||||||
|
)
|
||||||
|
model_parser.add_argument("--tp-sizes",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_TP_SIZES)
|
||||||
|
model_parser.add_argument("--batch-sizes",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_BATCH_SIZES)
|
||||||
|
model_parser.set_defaults(func=run_model_bench)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.func(args)
|
||||||
@@ -30,19 +30,36 @@ def benchmark_config(
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
use_fp8: bool,
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
num_iters: int = 100,
|
num_iters: int = 100,
|
||||||
) -> float:
|
) -> float:
|
||||||
init_dtype = torch.float16 if use_fp8 else dtype
|
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
w1 = torch.randn(num_experts,
|
if use_int8_w8a16:
|
||||||
shard_intermediate_size,
|
w1 = torch.randint(-127,
|
||||||
hidden_size,
|
127, (
|
||||||
dtype=init_dtype)
|
num_experts,
|
||||||
w2 = torch.randn(num_experts,
|
shard_intermediate_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
shard_intermediate_size // 2,
|
),
|
||||||
dtype=init_dtype)
|
dtype=torch.int8)
|
||||||
|
w2 = torch.randint(-127,
|
||||||
|
127, (
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
shard_intermediate_size // 2,
|
||||||
|
),
|
||||||
|
dtype=torch.int8)
|
||||||
|
else:
|
||||||
|
w1 = torch.randn(num_experts,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
dtype=init_dtype)
|
||||||
|
w2 = torch.randn(num_experts,
|
||||||
|
hidden_size,
|
||||||
|
shard_intermediate_size // 2,
|
||||||
|
dtype=init_dtype)
|
||||||
gating_output = torch.randn(num_iters,
|
gating_output = torch.randn(num_iters,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -52,7 +69,11 @@ def benchmark_config(
|
|||||||
w2_scale = None
|
w2_scale = None
|
||||||
a1_scale = None
|
a1_scale = None
|
||||||
a2_scale = None
|
a2_scale = None
|
||||||
if use_fp8:
|
if use_int8_w8a16:
|
||||||
|
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
|
||||||
|
dtype=torch.float32)
|
||||||
|
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||||
|
if use_fp8_w8a8:
|
||||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||||
@@ -76,7 +97,8 @@ def benchmark_config(
|
|||||||
renormalize=True,
|
renormalize=True,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
override_config=config,
|
override_config=config,
|
||||||
use_fp8=use_fp8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
w2_scale=w2_scale,
|
w2_scale=w2_scale,
|
||||||
a1_scale=a1_scale,
|
a1_scale=a1_scale,
|
||||||
@@ -155,11 +177,13 @@ class BenchmarkWorker:
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
use_fp8: bool,
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
) -> Tuple[Dict[str, int], float]:
|
) -> Tuple[Dict[str, int], float]:
|
||||||
torch.cuda.manual_seed_all(self.seed)
|
torch.cuda.manual_seed_all(self.seed)
|
||||||
|
dtype_str = get_config_dtype_str(dtype,
|
||||||
dtype_str = "float8" if use_fp8 else None
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8)
|
||||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||||
# is the intermediate size after silu_and_mul.
|
# is the intermediate size after silu_and_mul.
|
||||||
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
|
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
|
||||||
@@ -173,7 +197,8 @@ class BenchmarkWorker:
|
|||||||
key=lambda x: abs(x - num_tokens))]
|
key=lambda x: abs(x - num_tokens))]
|
||||||
kernel_time = benchmark_config(config, num_tokens, num_experts,
|
kernel_time = benchmark_config(config, num_tokens, num_experts,
|
||||||
shard_intermediate_size, hidden_size,
|
shard_intermediate_size, hidden_size,
|
||||||
topk, dtype, use_fp8)
|
topk, dtype, use_fp8_w8a8,
|
||||||
|
use_int8_w8a16)
|
||||||
return config, kernel_time
|
return config, kernel_time
|
||||||
|
|
||||||
def tune(
|
def tune(
|
||||||
@@ -184,9 +209,10 @@ class BenchmarkWorker:
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
use_fp8: bool,
|
use_fp8_w8a8: bool,
|
||||||
search_space: List[BenchmarkConfig],
|
use_int8_w8a16: bool,
|
||||||
) -> BenchmarkConfig:
|
search_space: List[Dict[str, int]],
|
||||||
|
) -> Dict[str, int]:
|
||||||
best_config = None
|
best_config = None
|
||||||
best_time = float("inf")
|
best_time = float("inf")
|
||||||
for config in tqdm(search_space):
|
for config in tqdm(search_space):
|
||||||
@@ -198,7 +224,8 @@ class BenchmarkWorker:
|
|||||||
hidden_size,
|
hidden_size,
|
||||||
topk,
|
topk,
|
||||||
dtype,
|
dtype,
|
||||||
use_fp8,
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
num_iters=10)
|
num_iters=10)
|
||||||
except triton.runtime.autotuner.OutOfResources:
|
except triton.runtime.autotuner.OutOfResources:
|
||||||
# Some configurations may be invalid and fail to compile.
|
# Some configurations may be invalid and fail to compile.
|
||||||
@@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def save_configs(
|
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
||||||
configs: Dict[int, BenchmarkConfig],
|
shard_intermediate_size: int, hidden_size: int, topk: int,
|
||||||
num_experts: int,
|
dtype: torch.dtype, use_fp8_w8a8: bool,
|
||||||
shard_intermediate_size: int,
|
use_int8_w8a16: bool) -> None:
|
||||||
hidden_size: int,
|
dtype_str = get_config_dtype_str(dtype,
|
||||||
topk: int,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
dtype: torch.dtype,
|
use_fp8_w8a8=use_fp8_w8a8)
|
||||||
use_fp8: bool,
|
|
||||||
) -> None:
|
|
||||||
dtype_str = "float8" if use_fp8 else None
|
|
||||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||||
# is the intermediate size after silu_and_mul.
|
# is the intermediate size after silu_and_mul.
|
||||||
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
|
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
|
||||||
dtype_str)
|
dtype_str)
|
||||||
|
|
||||||
print(f"Writing best config to {filename}...")
|
print(f"Writing best config to {filename}...")
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
json.dump(configs, f, indent=4)
|
json.dump(configs, f, indent=4)
|
||||||
@@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
|
|||||||
topk = config.ffn_config.moe_top_k
|
topk = config.ffn_config.moe_top_k
|
||||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
|
elif config.architectures[0] == "JambaForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
else:
|
else:
|
||||||
# Default: Mixtral.
|
# Default: Mixtral.
|
||||||
E = config.num_local_experts
|
E = config.num_local_experts
|
||||||
@@ -262,7 +293,8 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
dtype = config.torch_dtype
|
dtype = config.torch_dtype
|
||||||
use_fp8 = args.dtype == "fp8"
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
|
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||||
|
|
||||||
if args.batch_size is None:
|
if args.batch_size is None:
|
||||||
batch_sizes = [
|
batch_sizes = [
|
||||||
@@ -294,21 +326,21 @@ def main(args: argparse.Namespace):
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
configs = _distribute(
|
configs = _distribute(
|
||||||
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
|
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||||
topk, dtype, use_fp8, search_space)
|
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
|
||||||
for batch_size in batch_sizes])
|
for batch_size in batch_sizes])
|
||||||
best_configs = {
|
best_configs = {
|
||||||
M: sort_config(config)
|
M: sort_config(config)
|
||||||
for M, config in zip(batch_sizes, configs)
|
for M, config in zip(batch_sizes, configs)
|
||||||
}
|
}
|
||||||
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
|
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
|
||||||
topk, dtype, use_fp8)
|
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"Tuning took {end - start:.2f} seconds")
|
print(f"Tuning took {end - start:.2f} seconds")
|
||||||
else:
|
else:
|
||||||
outputs = _distribute("benchmark",
|
outputs = _distribute(
|
||||||
[(batch_size, E, shard_intermediate_size,
|
"benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||||
hidden_size, topk, dtype, use_fp8)
|
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
|
||||||
for batch_size in batch_sizes])
|
for batch_size in batch_sizes])
|
||||||
|
|
||||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||||
print(f"Batch size: {batch_size}, config: {config}")
|
print(f"Batch size: {batch_size}, config: {config}")
|
||||||
@@ -323,7 +355,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--tp-size", "-tp", type=int, default=2)
|
parser.add_argument("--tp-size", "-tp", type=int, default=2)
|
||||||
parser.add_argument("--dtype",
|
parser.add_argument("--dtype",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["auto", "fp8"],
|
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||||
default="auto")
|
default="auto")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--batch-size", type=int, required=False)
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
|
|||||||
103
benchmarks/kernels/benchmark_quant.py
Normal file
103
benchmarks/kernels/benchmark_quant.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main(num_tokens: int,
|
||||||
|
hidden_size: int,
|
||||||
|
static_scale: bool,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int = 0,
|
||||||
|
do_profile: bool = False,
|
||||||
|
num_warmup_iters: int = 5,
|
||||||
|
num_iters: int = 100) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None
|
||||||
|
|
||||||
|
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if profile:
|
||||||
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
for _ in range(num_iters):
|
||||||
|
if quant_dtype == torch.int8:
|
||||||
|
ops.scaled_int8_quant(x, scale)
|
||||||
|
else:
|
||||||
|
ops.scaled_fp8_quant(x, scale)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
if profile:
|
||||||
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
|
return (end_time - start_time) / num_iters
|
||||||
|
|
||||||
|
# Warmup.
|
||||||
|
print("Warming up...")
|
||||||
|
run_benchmark = run_cuda_benchmark
|
||||||
|
run_benchmark(num_iters=num_warmup_iters, profile=False)
|
||||||
|
|
||||||
|
# Benchmark.
|
||||||
|
if do_profile:
|
||||||
|
latency = run_benchmark(num_iters=1, profile=True)
|
||||||
|
else:
|
||||||
|
latency = run_benchmark(num_iters=num_iters, profile=False)
|
||||||
|
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
def to_torch_dtype(dt):
|
||||||
|
if dt == "int8":
|
||||||
|
return torch.int8
|
||||||
|
if dt == "fp8":
|
||||||
|
return torch.float8_e4m3fn
|
||||||
|
raise ValueError(f"Unsupported dtype: {dt}")
|
||||||
|
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the quantization (fp8 or int8) kernel.")
|
||||||
|
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||||
|
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||||
|
parser.add_argument("--static-scale", action="store_true")
|
||||||
|
parser.add_argument("--quant-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["fp8", "int8"],
|
||||||
|
default="int8")
|
||||||
|
parser.add_argument("--dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["half", "bfloat16", "float"],
|
||||||
|
default="half")
|
||||||
|
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--profile", action="store_true")
|
||||||
|
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||||
|
parser.add_argument("--num-iters",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Number of benchmark iterations. "
|
||||||
|
"If --profile is set, this number is ignored")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
main(num_tokens=args.num_tokens,
|
||||||
|
hidden_size=args.hidden_size,
|
||||||
|
static_scale=args.static_scale,
|
||||||
|
quant_dtype=to_torch_dtype(args.quant_dtype),
|
||||||
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
|
seed=args.seed,
|
||||||
|
do_profile=args.profile,
|
||||||
|
num_warmup_iters=args.num_warmup_iters,
|
||||||
|
num_iters=args.num_iters)
|
||||||
64
benchmarks/kernels/graph_machete_bench.py
Normal file
64
benchmarks/kernels/graph_machete_bench.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import math
|
||||||
|
import pickle
|
||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description='Benchmark the latency of processing a single batch of '
|
||||||
|
'requests till completion.')
|
||||||
|
parser.add_argument('filename', type=str)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.filename, 'rb') as f:
|
||||||
|
data: List[TMeasurement] = pickle.load(f)
|
||||||
|
|
||||||
|
results = defaultdict(lambda: list())
|
||||||
|
for v in data:
|
||||||
|
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
|
||||||
|
if result is not None:
|
||||||
|
KN = result.group(1)
|
||||||
|
else:
|
||||||
|
raise Exception("MKN not found")
|
||||||
|
result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
|
||||||
|
if result is not None:
|
||||||
|
M = result.group(1)
|
||||||
|
else:
|
||||||
|
raise Exception("MKN not found")
|
||||||
|
|
||||||
|
kernel = v.task_spec.description
|
||||||
|
results[KN].append({
|
||||||
|
"kernel": kernel,
|
||||||
|
"batch_size": M,
|
||||||
|
"median": v.median
|
||||||
|
})
|
||||||
|
|
||||||
|
rows = int(math.ceil(len(results) / 2))
|
||||||
|
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
|
||||||
|
axs = axs.flatten()
|
||||||
|
axs_idx = 0
|
||||||
|
for shape, data in results.items():
|
||||||
|
plt.sca(axs[axs_idx])
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
sns.lineplot(data=df,
|
||||||
|
x="batch_size",
|
||||||
|
y="median",
|
||||||
|
hue="kernel",
|
||||||
|
style="kernel",
|
||||||
|
markers=True,
|
||||||
|
dashes=False,
|
||||||
|
palette="Dark2")
|
||||||
|
plt.title(f"Shape: {shape}")
|
||||||
|
plt.ylabel("time (median, s)")
|
||||||
|
axs_idx += 1
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("graph_machete_bench.pdf")
|
||||||
43
benchmarks/kernels/weight_shapes.py
Normal file
43
benchmarks/kernels/weight_shapes.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# Weight Shapes are in the format
|
||||||
|
# ([K, N], TP_SPLIT_DIM)
|
||||||
|
# Example:
|
||||||
|
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 14336, N = 4096
|
||||||
|
# - TP2 : K = 7168, N = 4096
|
||||||
|
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 4096, N = 6144
|
||||||
|
# - TP4 : K = 4096, N = 1536
|
||||||
|
|
||||||
|
# TP1 shapes
|
||||||
|
WEIGHT_SHAPES = {
|
||||||
|
"mistralai/Mistral-7B-v0.1": [
|
||||||
|
([4096, 6144], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 28672], 1),
|
||||||
|
([14336, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-7b-hf": [
|
||||||
|
([4096, 12288], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 22016], 1),
|
||||||
|
([11008, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-3-8b": [
|
||||||
|
([4096, 6144], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 28672], 1),
|
||||||
|
([14336, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-13b-hf": [
|
||||||
|
([5120, 15360], 1),
|
||||||
|
([5120, 5120], 0),
|
||||||
|
([5120, 27648], 1),
|
||||||
|
([13824, 5120], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-70b-hf": [
|
||||||
|
([8192, 10240], 1),
|
||||||
|
([8192, 8192], 0),
|
||||||
|
([8192, 57344], 1),
|
||||||
|
([28672, 8192], 0),
|
||||||
|
],
|
||||||
|
}
|
||||||
@@ -6,7 +6,7 @@ TOKENS=$2
|
|||||||
|
|
||||||
docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
|
docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
|
||||||
-v $PWD/data:/data \
|
-v $PWD/data:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:1.4.0 \
|
ghcr.io/huggingface/text-generation-inference:2.2.0 \
|
||||||
--model-id $MODEL \
|
--model-id $MODEL \
|
||||||
--sharded false \
|
--sharded false \
|
||||||
--max-input-length 1024 \
|
--max-input-length 1024 \
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Define environment variables for special configurations
|
# Define environment variables for special configurations
|
||||||
@@ -83,12 +84,7 @@ endif()
|
|||||||
|
|
||||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||||
|
|
||||||
list(APPEND LIBS "numa")
|
list(APPEND LIBS dnnl numa)
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
# Define extension targets
|
|
||||||
#
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# _C extension
|
# _C extension
|
||||||
@@ -102,6 +98,16 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/cpu/pos_encoding.cpp"
|
"csrc/cpu/pos_encoding.cpp"
|
||||||
"csrc/cpu/torch_bindings.cpp")
|
"csrc/cpu/torch_bindings.cpp")
|
||||||
|
|
||||||
|
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||||
|
set(VLLM_EXT_SRC
|
||||||
|
"csrc/cpu/quant.cpp"
|
||||||
|
${VLLM_EXT_SRC})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# Define extension targets
|
||||||
|
#
|
||||||
|
|
||||||
define_gpu_extension_target(
|
define_gpu_extension_target(
|
||||||
_C
|
_C
|
||||||
DESTINATION vllm
|
DESTINATION vllm
|
||||||
|
|||||||
@@ -350,6 +350,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
|
|||||||
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
|
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
|
||||||
${GPU_INCLUDE_DIRECTORIES})
|
${GPU_INCLUDE_DIRECTORIES})
|
||||||
|
|
||||||
|
# TODO: is torch_python_LIBRARY needed?
|
||||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
|
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
|
||||||
${GPU_LIBRARIES})
|
${GPU_LIBRARIES})
|
||||||
|
|
||||||
|
|||||||
@@ -66,6 +66,8 @@ DEFAULT_CONDA_PATTERNS = {
|
|||||||
"nccl",
|
"nccl",
|
||||||
"transformers",
|
"transformers",
|
||||||
"zmq",
|
"zmq",
|
||||||
|
"nvidia",
|
||||||
|
"pynvml",
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_PIP_PATTERNS = {
|
DEFAULT_PIP_PATTERNS = {
|
||||||
@@ -79,6 +81,8 @@ DEFAULT_PIP_PATTERNS = {
|
|||||||
"nccl",
|
"nccl",
|
||||||
"transformers",
|
"transformers",
|
||||||
"zmq",
|
"zmq",
|
||||||
|
"nvidia",
|
||||||
|
"pynvml",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -265,8 +269,9 @@ def get_neuron_sdk_version(run_lambda):
|
|||||||
def get_vllm_version():
|
def get_vllm_version():
|
||||||
try:
|
try:
|
||||||
import vllm
|
import vllm
|
||||||
return vllm.__version__
|
return vllm.__version__ + "@" + vllm.__commit__
|
||||||
except ImportError:
|
except Exception:
|
||||||
|
# old version of vllm does not have __commit__
|
||||||
return 'N/A'
|
return 'N/A'
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
|||||||
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
|
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 1; ii < N; ++ii) {
|
for (int ii = 1; ii < N; ++ii) {
|
||||||
qk_vec = fma(q[ii], k[ii], qk_vec);
|
qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finalize the reduction across lanes.
|
// Finalize the reduction across lanes.
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ namespace vllm {
|
|||||||
//
|
//
|
||||||
class ScalarType {
|
class ScalarType {
|
||||||
public:
|
public:
|
||||||
enum NanRepr : int64_t {
|
enum NanRepr : uint8_t {
|
||||||
NAN_NONE = 0, // nans are not supported
|
NAN_NONE = 0, // nans are not supported
|
||||||
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
||||||
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
||||||
@@ -28,33 +28,33 @@ class ScalarType {
|
|||||||
NAN_REPR_ID_MAX
|
NAN_REPR_ID_MAX
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
|
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
|
||||||
int64_t bias, bool finite_values_only = false,
|
int32_t bias, bool finite_values_only = false,
|
||||||
NanRepr nan_repr = NAN_IEEE_754)
|
NanRepr nan_repr = NAN_IEEE_754)
|
||||||
: exponent(exponent),
|
: exponent(exponent),
|
||||||
mantissa(mantissa),
|
mantissa(mantissa),
|
||||||
bias(bias),
|
|
||||||
signed_(signed_),
|
signed_(signed_),
|
||||||
|
bias(bias),
|
||||||
finite_values_only(finite_values_only),
|
finite_values_only(finite_values_only),
|
||||||
nan_repr(nan_repr){};
|
nan_repr(nan_repr){};
|
||||||
|
|
||||||
static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
|
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||||
return ScalarType(true, 0, size_bits - 1, bias);
|
return ScalarType(0, size_bits - 1, true, bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
|
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
|
||||||
return ScalarType(false, 0, size_bits, bias);
|
return ScalarType(0, size_bits, false, bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
// IEEE 754 compliant floating point type
|
// IEEE 754 compliant floating point type
|
||||||
static constexpr ScalarType float_IEEE754(int64_t exponent,
|
static constexpr ScalarType float_IEEE754(uint8_t exponent,
|
||||||
int64_t mantissa) {
|
uint8_t mantissa) {
|
||||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||||
return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
|
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
|
||||||
}
|
}
|
||||||
|
|
||||||
// IEEE 754 non-compliant floating point type
|
// IEEE 754 non-compliant floating point type
|
||||||
static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
|
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
|
||||||
bool finite_values_only,
|
bool finite_values_only,
|
||||||
NanRepr nan_repr) {
|
NanRepr nan_repr) {
|
||||||
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||||
@@ -62,36 +62,121 @@ class ScalarType {
|
|||||||
TORCH_CHECK(nan_repr != NAN_IEEE_754,
|
TORCH_CHECK(nan_repr != NAN_IEEE_754,
|
||||||
"use `float_IEEE754` constructor for floating point types that "
|
"use `float_IEEE754` constructor for floating point types that "
|
||||||
"follow IEEE 754 conventions");
|
"follow IEEE 754 conventions");
|
||||||
return ScalarType(true, exponent, mantissa, 0, finite_values_only,
|
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
|
||||||
nan_repr);
|
nan_repr);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t const exponent; // size of the exponent field (0 for integer types)
|
uint8_t const exponent; // size of the exponent field (0 for integer types)
|
||||||
int64_t const mantissa; // size of the mantissa field (size of the integer
|
uint8_t const mantissa; // size of the mantissa field (size of the integer
|
||||||
// excluding the sign bit for integer types)
|
// excluding the sign bit for integer types)
|
||||||
int64_t const bias; // stored values equal value + bias,
|
|
||||||
// used for quantized type
|
|
||||||
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
||||||
// sign bit)
|
// sign bit)
|
||||||
|
int32_t const bias; // stored values equal value + bias,
|
||||||
|
// used for quantized type
|
||||||
|
|
||||||
// Extra Floating point info
|
// Extra Floating point info
|
||||||
bool const finite_values_only; // i.e. no +/-inf if true
|
bool const finite_values_only; // i.e. no +/-inf if true
|
||||||
NanRepr const nan_repr; // how NaNs are represented
|
NanRepr const nan_repr; // how NaNs are represented
|
||||||
// (not applicable for integer types)
|
// (not applicable for integer types)
|
||||||
|
|
||||||
int64_t size_bits() const { return mantissa + exponent + is_signed(); }
|
using Id = int64_t;
|
||||||
bool is_signed() const { return signed_; }
|
|
||||||
bool is_integer() const { return exponent == 0; }
|
private:
|
||||||
bool is_floating_point() const { return exponent > 0; }
|
// Field size in id
|
||||||
bool is_ieee_754() const {
|
template <typename T_>
|
||||||
|
static constexpr size_t member_id_field_width() {
|
||||||
|
using T = std::decay_t<T_>;
|
||||||
|
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Fn, typename Init, typename Member, typename... Rest>
|
||||||
|
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
|
||||||
|
Rest... rest) {
|
||||||
|
auto new_val = f(val, member);
|
||||||
|
if constexpr (sizeof...(rest) > 0) {
|
||||||
|
return reduce_members_helper(f, new_val, rest...);
|
||||||
|
} else {
|
||||||
|
return new_val;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Fn, typename Init>
|
||||||
|
constexpr auto reduce_members(Fn f, Init init) const {
|
||||||
|
// Should be in constructor order for `from_id`
|
||||||
|
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
|
||||||
|
finite_values_only, nan_repr);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Fn, typename Init>
|
||||||
|
static constexpr auto reduce_member_types(Fn f, Init init) {
|
||||||
|
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
|
||||||
|
return dummy_type.reduce_members(f, init);
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr auto id_size_bits() {
|
||||||
|
return reduce_member_types(
|
||||||
|
[](int acc, auto member) -> int {
|
||||||
|
return acc + member_id_field_width<decltype(member)>();
|
||||||
|
},
|
||||||
|
0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// unique id for this scalar type that can be computed at compile time for
|
||||||
|
// c++17 template specialization this is not needed once we migrate to
|
||||||
|
// c++20 and can pass literal classes as template parameters
|
||||||
|
constexpr Id id() const {
|
||||||
|
static_assert(id_size_bits() <= sizeof(Id) * 8,
|
||||||
|
"ScalarType id is too large to be stored");
|
||||||
|
|
||||||
|
auto or_and_advance = [](std::pair<Id, uint32_t> result,
|
||||||
|
auto member) -> std::pair<Id, uint32_t> {
|
||||||
|
auto [id, bit_offset] = result;
|
||||||
|
auto constexpr bits = member_id_field_width<decltype(member)>();
|
||||||
|
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
|
||||||
|
<< bit_offset,
|
||||||
|
bit_offset + bits};
|
||||||
|
};
|
||||||
|
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a ScalarType from an id, for c++17 template specialization,
|
||||||
|
// this is not needed once we migrate to c++20 and can pass literal
|
||||||
|
// classes as template parameters
|
||||||
|
static constexpr ScalarType from_id(Id id) {
|
||||||
|
auto extract_and_advance = [id](auto result, auto member) {
|
||||||
|
using T = decltype(member);
|
||||||
|
auto [tuple, bit_offset] = result;
|
||||||
|
auto constexpr bits = member_id_field_width<T>();
|
||||||
|
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
|
||||||
|
((uint64_t(1) << bits) - 1));
|
||||||
|
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
|
||||||
|
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
|
||||||
|
std::pair<std::tuple<>, int>{});
|
||||||
|
return std::apply([](auto... args) { return ScalarType(args...); },
|
||||||
|
tuple_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int64_t size_bits() const {
|
||||||
|
return mantissa + exponent + is_signed();
|
||||||
|
}
|
||||||
|
constexpr bool is_signed() const { return signed_; }
|
||||||
|
constexpr bool is_integer() const { return exponent == 0; }
|
||||||
|
constexpr bool is_floating_point() const { return exponent > 0; }
|
||||||
|
constexpr bool is_ieee_754() const {
|
||||||
return is_floating_point() && finite_values_only == false &&
|
return is_floating_point() && finite_values_only == false &&
|
||||||
nan_repr == NAN_IEEE_754;
|
nan_repr == NAN_IEEE_754;
|
||||||
}
|
}
|
||||||
bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
|
constexpr bool has_nans() const {
|
||||||
bool has_infs() const {
|
return is_floating_point() && nan_repr != NAN_NONE;
|
||||||
|
}
|
||||||
|
constexpr bool has_infs() const {
|
||||||
return is_floating_point() && finite_values_only == false;
|
return is_floating_point() && finite_values_only == false;
|
||||||
}
|
}
|
||||||
bool has_bias() const { return bias != 0; }
|
constexpr bool has_bias() const { return bias != 0; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double _floating_point_max() const {
|
double _floating_point_max() const {
|
||||||
@@ -131,7 +216,7 @@ class ScalarType {
|
|||||||
return *reinterpret_cast<double*>(&double_raw);
|
return *reinterpret_cast<double*>(&double_raw);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::variant<int64_t, double> _raw_max() const {
|
constexpr std::variant<int64_t, double> _raw_max() const {
|
||||||
if (is_floating_point()) {
|
if (is_floating_point()) {
|
||||||
return {_floating_point_max()};
|
return {_floating_point_max()};
|
||||||
} else {
|
} else {
|
||||||
@@ -141,7 +226,7 @@ class ScalarType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::variant<int64_t, double> _raw_min() const {
|
constexpr std::variant<int64_t, double> _raw_min() const {
|
||||||
if (is_floating_point()) {
|
if (is_floating_point()) {
|
||||||
TORCH_CHECK(is_signed(),
|
TORCH_CHECK(is_signed(),
|
||||||
"We currently assume all floating point types are signed");
|
"We currently assume all floating point types are signed");
|
||||||
@@ -168,7 +253,7 @@ class ScalarType {
|
|||||||
public:
|
public:
|
||||||
// Max representable value for this scalar type.
|
// Max representable value for this scalar type.
|
||||||
// (accounting for bias if there is one)
|
// (accounting for bias if there is one)
|
||||||
std::variant<int64_t, double> max() const {
|
constexpr std::variant<int64_t, double> max() const {
|
||||||
return std::visit(
|
return std::visit(
|
||||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||||
_raw_max());
|
_raw_max());
|
||||||
@@ -176,7 +261,7 @@ class ScalarType {
|
|||||||
|
|
||||||
// Min representable value for this scalar type.
|
// Min representable value for this scalar type.
|
||||||
// (accounting for bias if there is one)
|
// (accounting for bias if there is one)
|
||||||
std::variant<int64_t, double> min() const {
|
constexpr std::variant<int64_t, double> min() const {
|
||||||
return std::visit(
|
return std::visit(
|
||||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||||
_raw_min());
|
_raw_min());
|
||||||
@@ -215,7 +300,7 @@ class ScalarType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(ScalarType const& other) const {
|
constexpr bool operator==(ScalarType const& other) const {
|
||||||
return mantissa == other.mantissa && exponent == other.exponent &&
|
return mantissa == other.mantissa && exponent == other.exponent &&
|
||||||
bias == other.bias && signed_ == other.signed_ &&
|
bias == other.bias && signed_ == other.signed_ &&
|
||||||
finite_values_only == other.finite_values_only &&
|
finite_values_only == other.finite_values_only &&
|
||||||
@@ -228,6 +313,8 @@ class ScalarType {
|
|||||||
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
|
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
|
||||||
// constructor at the same time (torch::CustomClassHolder does not have a
|
// constructor at the same time (torch::CustomClassHolder does not have a
|
||||||
// constexpr destructor)
|
// constexpr destructor)
|
||||||
|
// See also:
|
||||||
|
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
||||||
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||||
public:
|
public:
|
||||||
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
|
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
|
||||||
@@ -240,31 +327,91 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
|||||||
using Self = ScalarTypeTorch;
|
using Self = ScalarTypeTorch;
|
||||||
using SelfPtr = c10::intrusive_ptr<Self>;
|
using SelfPtr = c10::intrusive_ptr<Self>;
|
||||||
|
|
||||||
|
static void check_size_bits(int64_t size_bits, bool signed_) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
size_bits <=
|
||||||
|
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
|
||||||
|
"size_bits bit width is too large to be represented");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void check_bias(int64_t bias) {
|
||||||
|
using Bias = decltype(std::declval<Self>().bias);
|
||||||
|
TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
|
||||||
|
bias >= std::numeric_limits<Bias>::min(),
|
||||||
|
"bias too large or small to be represented");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void check_exponent(int64_t exponent) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
exponent <=
|
||||||
|
std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
|
||||||
|
"exponent bit width is too large to be represented");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void check_mantissa(int64_t mantissa) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
mantissa <=
|
||||||
|
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
|
||||||
|
"mantissa bit width is too large to be represented");
|
||||||
|
}
|
||||||
|
|
||||||
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
|
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
|
||||||
|
check_size_bits(size_bits, true);
|
||||||
|
check_bias(bias.value_or(0));
|
||||||
return c10::make_intrusive<Self>(
|
return c10::make_intrusive<Self>(
|
||||||
ScalarType::int_(size_bits, bias.value_or(0)));
|
ScalarType::int_(size_bits, bias.value_or(0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
|
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
|
||||||
|
check_size_bits(size_bits, true);
|
||||||
|
check_bias(bias.value_or(0));
|
||||||
return c10::make_intrusive<Self>(
|
return c10::make_intrusive<Self>(
|
||||||
ScalarType::uint(size_bits, bias.value_or(0)));
|
ScalarType::uint(size_bits, bias.value_or(0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
|
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
|
||||||
|
check_mantissa(mantissa);
|
||||||
|
check_exponent(exponent);
|
||||||
return c10::make_intrusive<Self>(
|
return c10::make_intrusive<Self>(
|
||||||
ScalarType::float_IEEE754(exponent, mantissa));
|
ScalarType::float_IEEE754(exponent, mantissa));
|
||||||
}
|
}
|
||||||
|
|
||||||
static SelfPtr float_(int64_t exponent, int64_t mantissa,
|
static SelfPtr float_(int64_t exponent, int64_t mantissa,
|
||||||
bool finite_values_only, int64_t nan_repr) {
|
bool finite_values_only, int64_t nan_repr) {
|
||||||
|
check_mantissa(mantissa);
|
||||||
|
check_exponent(exponent);
|
||||||
return c10::make_intrusive<Self>(ScalarType::float_(
|
return c10::make_intrusive<Self>(ScalarType::float_(
|
||||||
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
|
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This needs to be implemented and throw a TypeError in order for
|
||||||
|
// PyTorch's opcheck to work on ops that use ScalarTypes.
|
||||||
|
int64_t len() const {
|
||||||
|
throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
|
||||||
|
"__len__ not implemented");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize a ScalarType into a tuple of pairs. Where each pair
|
||||||
|
// is a (fieldname, value).
|
||||||
|
// For simplicity, we are just going to convert to a ScalarTypeId.
|
||||||
|
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
|
||||||
|
return {{"ScalarType", id()}};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deserialize a scalar type that has been serialized by obj_flatten,
|
||||||
|
// ostensibly from a tuple of (member name, value) pairs, but in reality
|
||||||
|
// just a ScalarTypeId.
|
||||||
|
static SelfPtr obj_unflatten(
|
||||||
|
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
|
||||||
|
return c10::make_intrusive<Self>(
|
||||||
|
from_id(std::get<1>(std::get<0>(flat_type))));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void bind_readonly_property(torch::class_<Self>& cls,
|
static void bind_readonly_property(torch::class_<Self>& cls,
|
||||||
std::string const& name, T Base::*field) {
|
std::string const& name, T Base::*field) {
|
||||||
auto getter_func = [field = std::move(field)](SelfPtr const& self) {
|
auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
|
||||||
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
|
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
|
||||||
return (self.get()->*field)();
|
return (self.get()->*field)();
|
||||||
} else {
|
} else {
|
||||||
@@ -272,6 +419,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto getter_func = [field = std::move(field),
|
||||||
|
getter_func_helper = std::move(getter_func_helper)](
|
||||||
|
SelfPtr const& self) {
|
||||||
|
auto val = getter_func_helper(self);
|
||||||
|
// upconvert uint8_t, int32_t etc. to int64_t for python
|
||||||
|
if constexpr (std::is_integral_v<T>) {
|
||||||
|
return static_cast<int64_t>(val);
|
||||||
|
} else {
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
cls.def_property(name, getter_func);
|
cls.def_property(name, getter_func);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,6 +483,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
|||||||
self.get()->min());
|
self.get()->min());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
bind_function(cls, "__len__", &ScalarTypeTorch::len);
|
||||||
bind_function(cls, "__str__", &Base::str);
|
bind_function(cls, "__str__", &Base::str);
|
||||||
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
|
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
|
||||||
return *self == *other;
|
return *self == *other;
|
||||||
@@ -332,6 +492,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
|||||||
return "ScalarType." + self.get()->str();
|
return "ScalarType." + self.get()->str();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
|
||||||
|
bind_static_function(cls, "__obj_unflatten__",
|
||||||
|
&ScalarTypeTorch::obj_unflatten);
|
||||||
|
|
||||||
// Bind static functions (convenience constructors)
|
// Bind static functions (convenience constructors)
|
||||||
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
|
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
|
||||||
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
|
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
|
||||||
@@ -340,6 +504,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using ScalarTypeId = int64_t;
|
||||||
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
|
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
|
||||||
|
|
||||||
// "rust style" names generally following:
|
// "rust style" names generally following:
|
||||||
@@ -379,4 +544,5 @@ static inline constexpr auto kHalf = kFE5M10;
|
|||||||
static inline constexpr auto kFloat16 = kHalf;
|
static inline constexpr auto kFloat16 = kHalf;
|
||||||
static inline constexpr auto kBFloat16 = kFE8M7;
|
static inline constexpr auto kBFloat16 = kFE8M7;
|
||||||
|
|
||||||
|
static inline constexpr auto kFloat16Id = kFloat16.id();
|
||||||
}; // namespace vllm
|
}; // namespace vllm
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ namespace vec_op {
|
|||||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||||
#else
|
#else
|
||||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||||
std::cout << #NAME << " invoked." << std::endl;
|
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
|
||||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||||
@@ -106,6 +106,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
|||||||
explicit BF16Vec16(const FP32Vec16 &);
|
explicit BF16Vec16(const FP32Vec16 &);
|
||||||
|
|
||||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||||
|
|
||||||
|
// Partial store: writes only the lowest `elem_num` 16-bit lanes of `reg` to
// `ptr`, leaving the remaining destination elements untouched.
// NOTE(review): the shift below is only well defined for 1 <= elem_num <= 32;
// callers appear to pass a tail length in [1, 15] — confirm.
void save(void* ptr, const int elem_num) const {
  constexpr uint32_t kAllOnes = 0xFFFFFFFF;
  // Low `elem_num` bits set; one mask bit per lane.
  const __mmask16 lane_mask = _cvtu32_mask16(kAllOnes >> (32 - elem_num));
  _mm256_mask_storeu_epi16(ptr, lane_mask, reg);
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef __AVX512F__
|
#ifdef __AVX512F__
|
||||||
@@ -313,8 +319,28 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
return FP32Vec16(_mm512_div_ps(reg, b.reg));
|
return FP32Vec16(_mm512_div_ps(reg, b.reg));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Lane-wise clamp: first raises each lane to at least `min`, then lowers the
// result to at most `max`.
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
  const __m512 raised = _mm512_max_ps(min.reg, reg);
  return FP32Vec16(_mm512_min_ps(max.reg, raised));
}
|
||||||
|
|
||||||
|
// Lane-wise maximum of this vector and `b`.
FP32Vec16 max(const FP32Vec16& b) const {
  const __m512 larger = _mm512_max_ps(reg, b.reg);
  return FP32Vec16(larger);
}
|
||||||
|
|
||||||
|
// Masked lane-wise maximum: the lowest `elem_num` lanes hold max(this, b);
// the remaining lanes keep this vector's original values.
// NOTE(review): the shift below is only well defined for 1 <= elem_num <= 32
// — confirm callers never pass 0.
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
  constexpr uint32_t kAllOnes = 0xFFFFFFFF;
  const __mmask16 lane_mask = _cvtu32_mask16(kAllOnes >> (32 - elem_num));
  // `reg` is also the pass-through source for lanes outside the mask.
  const __m512 merged = _mm512_mask_max_ps(reg, lane_mask, reg, b.reg);
  return FP32Vec16(merged);
}
|
||||||
|
|
||||||
|
// Lane-wise absolute value.
FP32Vec16 abs() const {
  const __m512 magnitude = _mm512_abs_ps(reg);
  return FP32Vec16(magnitude);
}
|
||||||
|
|
||||||
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
|
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
|
||||||
|
|
||||||
|
// Horizontal reduction: returns the largest of the 16 lanes.
float reduce_max() const {
  const float largest = _mm512_reduce_max_ps(reg);
  return largest;
}
|
||||||
|
|
||||||
template <int group_size> float reduce_sub_sum(int idx) {
|
template <int group_size> float reduce_sub_sum(int idx) {
|
||||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||||
@@ -323,6 +349,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
|
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||||
|
|
||||||
|
// Partial store: writes only the lowest `elem_num` float lanes of `reg` to
// `ptr`, leaving the remaining destination elements untouched.
// NOTE(review): the shift below is only well defined for 1 <= elem_num <= 32
// — confirm callers never pass 0.
void save(float* ptr, const int elem_num) const {
  constexpr uint32_t kAllOnes = 0xFFFFFFFF;
  // Low `elem_num` bits set; one mask bit per lane.
  const __mmask16 lane_mask = _cvtu32_mask16(kAllOnes >> (32 - elem_num));
  _mm512_mask_storeu_ps(ptr, lane_mask, reg);
}
|
||||||
};
|
};
|
||||||
#else
|
#else
|
||||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||||
@@ -433,6 +465,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
// Sixteen int8 lanes held in a 128-bit register. Built by converting an
// FP32Vec16 and written out with full or masked stores.
struct INT8Vec16 : public Vec<INT8Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  // Element-wise view of the register contents.
  union AliasReg {
    __m128i reg;
    int8_t values[VEC_ELEM_NUM];
  };

  __m128i reg;

  // Converts each float lane to a 32-bit integer using round-to-nearest with
  // floating-point exceptions suppressed, then narrows each lane to 8 bits.
  // NOTE(review): the narrowing intrinsic is the non-saturating variant, so
  // callers presumably clamp to the int8 range first — confirm.
  explicit INT8Vec16(const FP32Vec16& vec)
      : reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(
            vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {}

  // Stores all 16 lanes to `ptr` (unaligned store).
  void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); }

  // Stores only the lowest `elem_num` lanes to `ptr`.
  // NOTE(review): the shift below is only well defined for
  // 1 <= elem_num <= 32 — confirm callers never pass 0.
  void save(int8_t* ptr, const int elem_num) const {
    constexpr uint32_t kAllOnes = 0xFFFFFFFF;
    const __mmask16 lane_mask = _cvtu32_mask16(kAllOnes >> (32 - elem_num));
    _mm_mask_storeu_epi8(ptr, lane_mask, reg);
  }
};
|
||||||
|
#endif
|
||||||
|
|
||||||
template <typename T> struct VecType { using vec_type = void; };
|
template <typename T> struct VecType { using vec_type = void; };
|
||||||
|
|
||||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||||
|
|||||||
168
csrc/cpu/dnnl_helper.hpp
Normal file
168
csrc/cpu/dnnl_helper.hpp
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
#ifndef DNNL_HELPER_HPP
|
||||||
|
#define DNNL_HELPER_HPP
|
||||||
|
|
||||||
|
#include <c10/util/BFloat16.h>
|
||||||
|
|
||||||
|
#include "oneapi/dnnl/dnnl.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
struct DNNLType {
|
||||||
|
static constexpr dnnl::memory::data_type type =
|
||||||
|
dnnl::memory::data_type::undef;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct DNNLType<int8_t> {
|
||||||
|
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct DNNLType<int32_t> {
|
||||||
|
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct DNNLType<float> {
|
||||||
|
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct DNNLType<c10::BFloat16> {
|
||||||
|
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||||
|
return DNNLType<std::decay_t<T>>::type;
|
||||||
|
}
|
||||||
|
}; // namespace
|
||||||
|
|
||||||
|
template <bool InputNoScale>
|
||||||
|
class DNNLPrimitiveHelper {
|
||||||
|
public:
|
||||||
|
// I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
|
||||||
|
// A: [M, K], row-major
|
||||||
|
// B: [K, N], column-major
|
||||||
|
// C: [M, N], row-major
|
||||||
|
// bias: [N], row-major, optional
|
||||||
|
// a_scales: [MS]
|
||||||
|
// b_scales: [NS]
|
||||||
|
// Note: Due to the limitation of oneDNN
|
||||||
|
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||||
|
// not supported.
|
||||||
|
template <typename OutputT, typename BiasT>
|
||||||
|
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||||
|
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||||
|
dnnl_dim_t K, const float* a_scales,
|
||||||
|
const float* b_scales, dnnl_dim_t MS,
|
||||||
|
dnnl_dim_t NS) {
|
||||||
|
auto&& OutputType = get_dnnl_type<OutputT>();
|
||||||
|
auto&& BiasType = get_dnnl_type<BiasT>();
|
||||||
|
|
||||||
|
dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
|
||||||
|
dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
|
||||||
|
dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
|
||||||
|
|
||||||
|
dnnl::primitive_attr attr;
|
||||||
|
if constexpr (!InputNoScale) {
|
||||||
|
if (MS == 1) {
|
||||||
|
// per-tensor
|
||||||
|
attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||||
|
} else {
|
||||||
|
// per-token
|
||||||
|
TORCH_CHECK(false, "per-token quantization is unsupported.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (NS == 1) {
|
||||||
|
// per-tensor
|
||||||
|
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||||
|
} else {
|
||||||
|
// per-channel
|
||||||
|
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
dnnl::matmul::primitive_desc matmul_pd;
|
||||||
|
if (bias) {
|
||||||
|
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||||
|
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||||
|
bias_md, c_md, attr);
|
||||||
|
} else {
|
||||||
|
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||||
|
c_md, attr);
|
||||||
|
}
|
||||||
|
dnnl::matmul matmul(matmul_pd);
|
||||||
|
|
||||||
|
auto& engine = default_engine();
|
||||||
|
|
||||||
|
dnnl::memory a_m(a_md, engine, (void*)a);
|
||||||
|
dnnl::memory b_m(b_md, engine, (void*)b);
|
||||||
|
dnnl::memory c_m(c_md, engine, (void*)c);
|
||||||
|
dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||||
|
(void*)a_scales);
|
||||||
|
dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||||
|
(void*)b_scales);
|
||||||
|
|
||||||
|
auto& stream = default_stream();
|
||||||
|
if constexpr (InputNoScale) {
|
||||||
|
if (bias) {
|
||||||
|
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||||
|
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||||
|
matmul.execute(
|
||||||
|
stream, {
|
||||||
|
{DNNL_ARG_SRC, a_m},
|
||||||
|
{DNNL_ARG_WEIGHTS, b_m},
|
||||||
|
{DNNL_ARG_BIAS, bias_m},
|
||||||
|
{DNNL_ARG_DST, c_m},
|
||||||
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
matmul.execute(
|
||||||
|
stream, {
|
||||||
|
{DNNL_ARG_SRC, a_m},
|
||||||
|
{DNNL_ARG_WEIGHTS, b_m},
|
||||||
|
{DNNL_ARG_DST, c_m},
|
||||||
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (bias) {
|
||||||
|
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||||
|
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||||
|
matmul.execute(
|
||||||
|
stream, {
|
||||||
|
{DNNL_ARG_SRC, a_m},
|
||||||
|
{DNNL_ARG_WEIGHTS, b_m},
|
||||||
|
{DNNL_ARG_BIAS, bias_m},
|
||||||
|
{DNNL_ARG_DST, c_m},
|
||||||
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||||
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
matmul.execute(
|
||||||
|
stream, {
|
||||||
|
{DNNL_ARG_SRC, a_m},
|
||||||
|
{DNNL_ARG_WEIGHTS, b_m},
|
||||||
|
{DNNL_ARG_DST, c_m},
|
||||||
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||||
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static dnnl::engine& default_engine() {
|
||||||
|
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||||
|
return engine;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dnnl::stream& default_stream() {
|
||||||
|
static dnnl::stream stream(default_engine());
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
294
csrc/cpu/quant.cpp
Normal file
294
csrc/cpu/quant.cpp
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
#include "cpu_types.hpp"
|
||||||
|
#include "dnnl_helper.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename scalar_t>
|
||||||
|
struct KernelVecType {
|
||||||
|
using load_vec_type = void;
|
||||||
|
using cvt_vec_type = void;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<float> {
|
||||||
|
using load_vec_type = vec_op::FP32Vec16;
|
||||||
|
using cvt_vec_type = vec_op::FP32Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<c10::BFloat16> {
|
||||||
|
using load_vec_type = vec_op::BF16Vec16;
|
||||||
|
using cvt_vec_type = vec_op::FP32Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
template <typename scalar_t>
|
||||||
|
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||||
|
const float* scale, const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||||
|
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||||
|
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||||
|
|
||||||
|
constexpr float i8_min =
|
||||||
|
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||||
|
constexpr float i8_max =
|
||||||
|
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||||
|
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||||
|
const cvt_vec_t i8_min_vec(i8_min);
|
||||||
|
const cvt_vec_t i8_max_vec(i8_max);
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int i = 0; i < num_tokens; ++i) {
|
||||||
|
int j = 0;
|
||||||
|
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||||
|
load_vec_t elems(input + i * hidden_size + j);
|
||||||
|
cvt_vec_t elems_fp32(elems);
|
||||||
|
elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
|
||||||
|
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||||
|
elems_int8.save(output + i * hidden_size + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
load_vec_t elems(input + i * hidden_size + j);
|
||||||
|
cvt_vec_t elems_fp32(elems);
|
||||||
|
elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
|
||||||
|
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||||
|
|
||||||
|
if (j + vec_elem_num == hidden_size) {
|
||||||
|
elems_int8.save(output + i * hidden_size + j);
|
||||||
|
} else {
|
||||||
|
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||||
|
float* scale, const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||||
|
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||||
|
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int i = 0; i < num_tokens; ++i) {
|
||||||
|
cvt_vec_t max_abs(0.0);
|
||||||
|
{
|
||||||
|
int j = 0;
|
||||||
|
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||||
|
load_vec_t elems(input + i * hidden_size + j);
|
||||||
|
cvt_vec_t elems_fp32(elems);
|
||||||
|
max_abs = max_abs.max(elems_fp32.abs());
|
||||||
|
}
|
||||||
|
|
||||||
|
load_vec_t elems(input + i * hidden_size + j);
|
||||||
|
cvt_vec_t elems_fp32(elems);
|
||||||
|
|
||||||
|
if (j + vec_elem_num == hidden_size) {
|
||||||
|
max_abs = max_abs.max(elems_fp32.abs());
|
||||||
|
} else {
|
||||||
|
max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float scale_val = max_abs.reduce_max() / 127.0f;
|
||||||
|
scale[i] = scale_val;
|
||||||
|
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||||
|
|
||||||
|
{
|
||||||
|
int j = 0;
|
||||||
|
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||||
|
load_vec_t elems(input + i * hidden_size + j);
|
||||||
|
cvt_vec_t elems_fp32(elems);
|
||||||
|
elems_fp32 = (elems_fp32 * inv_scale);
|
||||||
|
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||||
|
elems_int8.save(output + i * hidden_size + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
load_vec_t elems(input + i * hidden_size + j);
|
||||||
|
cvt_vec_t elems_fp32(elems);
|
||||||
|
elems_fp32 = (elems_fp32 * inv_scale);
|
||||||
|
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||||
|
|
||||||
|
if (j + vec_elem_num == hidden_size) {
|
||||||
|
elems_int8.save(output + i * hidden_size + j);
|
||||||
|
} else {
|
||||||
|
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool Bias, typename scalar_t>
|
||||||
|
void dynamic_output_scale_impl(const float* input, scalar_t* output,
|
||||||
|
const float* scale, const scalar_t* bias,
|
||||||
|
const int num_tokens, const int hidden_size) {
|
||||||
|
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
|
||||||
|
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||||
|
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||||
|
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int i = 0; i < num_tokens; ++i) {
|
||||||
|
int j = 0;
|
||||||
|
cvt_vec_t token_scale_vec(scale[i]);
|
||||||
|
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||||
|
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||||
|
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||||
|
|
||||||
|
if constexpr (Bias) {
|
||||||
|
load_vec_t bias_vec(bias + j);
|
||||||
|
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||||
|
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||||
|
}
|
||||||
|
|
||||||
|
load_vec_t elems_out(elems_fp32);
|
||||||
|
elems_out.save(output + i * hidden_size + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||||
|
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||||
|
|
||||||
|
if constexpr (Bias) {
|
||||||
|
load_vec_t bias_vec(bias + j);
|
||||||
|
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||||
|
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||||
|
}
|
||||||
|
|
||||||
|
load_vec_t elems_out(elems_fp32);
|
||||||
|
|
||||||
|
if (j + vec_elem_num == hidden_size) {
|
||||||
|
elems_out.save(output + i * hidden_size + j);
|
||||||
|
} else {
|
||||||
|
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
template <typename scalar_t>
|
||||||
|
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||||
|
const float* scale, const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.")
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||||
|
float* scale, const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.")
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
void dynamic_output_scale_impl() {
|
||||||
|
TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.")
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// int8 x int8 scaled matmul on CPU: c = (a_scales * a) @ (b_scales * b) + bias.
//
// With a single activation scale, the whole computation is handed to oneDNN.
// With one scale per row of `a`, oneDNN computes the weight-scaled product
// into an fp32 temporary, and dynamic_output_scale_impl then applies the
// per-row scales (and the bias) while converting to c's dtype.
//
// Raises (via TORCH_CHECK) when dtypes, shapes, strides or scale sizes do not
// satisfy the checks below.
void int8_scaled_mm(torch::Tensor& c,               // [M, OC], row-major
                    const torch::Tensor& a,         // [M, IC], row-major
                    const torch::Tensor& b,         // [IC, OC], column-major
                    const torch::Tensor& a_scales,  // [1] or [M]
                    const torch::Tensor& b_scales,  // [1] or [OC]
                    const c10::optional<torch::Tensor>& bias  // [OC]
) {
  // Label the profiler record with this function's own name.
  CPU_KERNEL_GUARD_IN(int8_scaled_mm)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm only supports INT8 inputs.");
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
    if (a_scales.numel() != 1) {
      // per-token
      // Note: oneDNN doesn't support per-token activation quantization
      torch::Tensor tmp_fp32_out =
          torch::empty_like(c, ::at::ScalarType::Float);
      // No bias and no activation scale here; both are applied afterwards.
      DNNLPrimitiveHelper<true>::gemm_s8s8_jit(
          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
          tmp_fp32_out.data_ptr<float>(), (void*)(0), a.size(0), b.size(1),
          a.size(1), (float*)(0), b_scales.data_ptr<float>(), 0,
          b_scales.numel());
      if (bias.has_value()) {
        dynamic_output_scale_impl<true>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            a_scales.data_ptr<float>(), bias->data_ptr<scalar_t>(), c.size(0),
            c.size(1));
      } else {
        dynamic_output_scale_impl<false>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            a_scales.data_ptr<float>(), (scalar_t*)(0), c.size(0), c.size(1));
      }
    } else {
      // per-tensor
      // A null bias pointer makes gemm_s8s8_jit skip the bias term, so one
      // call covers both the biased and the unbiased case.
      const scalar_t* bias_ptr =
          bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr;
      DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
          bias_ptr, a.size(0), b.size(1), a.size(1),
          a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
          a_scales.numel(), b_scales.numel());
    }
  });
}
|
||||||
|
|
||||||
|
// static-per-tensor quantization.
|
||||||
|
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||||
|
const torch::Tensor& input, // [..., hidden_size]
|
||||||
|
const torch::Tensor& scale) {
|
||||||
|
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||||
|
TORCH_CHECK(input.is_contiguous());
|
||||||
|
TORCH_CHECK(out.is_contiguous());
|
||||||
|
TORCH_CHECK(scale.numel() == 1);
|
||||||
|
|
||||||
|
const int hidden_size = input.size(-1);
|
||||||
|
const int num_tokens = input.numel() / hidden_size;
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
|
||||||
|
static_scaled_int8_quant_impl(
|
||||||
|
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||||
|
scale.data_ptr<float>(), num_tokens, hidden_size);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// dynamic-per-token quantization.
|
||||||
|
void dynamic_scaled_int8_quant(
|
||||||
|
torch::Tensor& out, // [..., hidden_size]
|
||||||
|
const torch::Tensor& input, // [..., hidden_size]
|
||||||
|
torch::Tensor& scale // [..., 1]
|
||||||
|
) {
|
||||||
|
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||||
|
TORCH_CHECK(input.is_contiguous());
|
||||||
|
TORCH_CHECK(out.is_contiguous());
|
||||||
|
|
||||||
|
int const hidden_size = input.size(-1);
|
||||||
|
int const num_tokens = input.numel() / hidden_size;
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
|
||||||
|
dynamic_scaled_int8_quant_impl(
|
||||||
|
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||||
|
scale.data_ptr<float>(), num_tokens, hidden_size);
|
||||||
|
});
|
||||||
|
}
|
||||||
@@ -4,7 +4,12 @@
|
|||||||
|
|
||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
|
|
||||||
void init_cpu_threads_env(const std::string& cpu_ids);
|
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||||
|
|
||||||
|
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
|
||||||
|
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||||
|
const torch::Tensor& b_scales,
|
||||||
|
const c10::optional<torch::Tensor>& bias);
|
||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||||
// vLLM custom ops
|
// vLLM custom ops
|
||||||
@@ -27,8 +32,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
// PagedAttention V2.
|
// PagedAttention V2.
|
||||||
ops.def(
|
ops.def(
|
||||||
"paged_attention_v2("
|
"paged_attention_v2("
|
||||||
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
|
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
|
||||||
" Tensor tmp_out, Tensor query, Tensor key_cache,"
|
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
|
||||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
" int max_seq_len, Tensor? alibi_slopes,"
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
@@ -84,6 +89,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor! key, int head_size,"
|
" Tensor! key, int head_size,"
|
||||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||||
|
|
||||||
|
// Quantization
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
// Compute int8 quantized tensor for given scaling factor.
|
||||||
|
ops.def(
|
||||||
|
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
|
||||||
|
// Compute int8 quantized tensor and scaling factor
|
||||||
|
ops.def(
|
||||||
|
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
|
||||||
|
&dynamic_scaled_int8_quant);
|
||||||
|
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||||
|
// quantization.
|
||||||
|
ops.def(
|
||||||
|
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||||
|
" Tensor b, Tensor a_scales,"
|
||||||
|
" Tensor b_scales, Tensor? bias) -> ()");
|
||||||
|
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||||
@@ -95,8 +122,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
|
|
||||||
// Copy the cache blocks from src to dst.
|
// Copy the cache blocks from src to dst.
|
||||||
cache_ops.def(
|
cache_ops.def(
|
||||||
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
|
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
|
||||||
"block_mapping) -> ()");
|
"Tensor block_mapping) -> ()");
|
||||||
cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks);
|
cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks);
|
||||||
|
|
||||||
// Reshape the key and value tensors and cache them.
|
// Reshape the key and value tensors and cache them.
|
||||||
@@ -111,7 +138,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||||
// CPU utils
|
// CPU utils
|
||||||
utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
|
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
#include "cpu_types.hpp"
|
#include "cpu_types.hpp"
|
||||||
|
|
||||||
void init_cpu_threads_env(const std::string& cpu_ids) {
|
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||||
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
||||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||||
std::vector<int> omp_cpu_ids;
|
std::vector<int> omp_cpu_ids;
|
||||||
@@ -51,15 +51,40 @@ void init_cpu_threads_env(const std::string& cpu_ids) {
|
|||||||
torch::set_num_threads((int)omp_cpu_ids.size());
|
torch::set_num_threads((int)omp_cpu_ids.size());
|
||||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
|
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
|
||||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
|
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
|
||||||
|
|
||||||
|
std::vector<std::pair<int, int>> thread_core_mapping;
|
||||||
|
thread_core_mapping.reserve(omp_cpu_ids.size());
|
||||||
|
omp_lock_t writelock;
|
||||||
|
omp_init_lock(&writelock);
|
||||||
|
|
||||||
#pragma omp parallel for schedule(static, 1)
|
#pragma omp parallel for schedule(static, 1)
|
||||||
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
|
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
|
||||||
cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
|
cpu_set_t mask;
|
||||||
size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
|
CPU_ZERO(&mask);
|
||||||
CPU_ZERO_S(size, mask);
|
CPU_SET(omp_cpu_ids[i], &mask);
|
||||||
CPU_SET_S(omp_cpu_ids[i], size, mask);
|
int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask);
|
||||||
sched_setaffinity(0, sizeof(cpu_set_t), mask);
|
if (ret == -1) {
|
||||||
CPU_FREE(mask);
|
TORCH_CHECK(false,
|
||||||
|
"sched_setaffinity failed. errno: " + std::to_string(errno));
|
||||||
|
}
|
||||||
|
|
||||||
|
omp_set_lock(&writelock);
|
||||||
|
thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]);
|
||||||
|
omp_unset_lock(&writelock);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
omp_destroy_lock(&writelock);
|
||||||
|
|
||||||
numa_free_nodemask(omp_cpu_mask);
|
numa_free_nodemask(omp_cpu_mask);
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "OMP threads binding of Process " << getpid() << ":\n";
|
||||||
|
std::sort(thread_core_mapping.begin(), thread_core_mapping.end(),
|
||||||
|
[](auto&& a, auto&& b) { return a.second < b.second; });
|
||||||
|
for (auto&& item : thread_core_mapping) {
|
||||||
|
ss << "\t"
|
||||||
|
<< "OMP tid: " << item.first << ", core " << item.second << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,15 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
|
||||||
|
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
|
||||||
|
#define DEVICE_INLINE __forceinline__ __device__
|
||||||
|
#define HOST_INLINE __forceinline__ __host__
|
||||||
|
#else
|
||||||
|
#define HOST_DEVICE_INLINE inline
|
||||||
|
#define DEVICE_INLINE inline
|
||||||
|
#define HOST_INLINE inline
|
||||||
|
#endif
|
||||||
|
|
||||||
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
|
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
|
||||||
|
|
||||||
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
|
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
|
||||||
|
|||||||
68
csrc/cutlass_extensions/cute_utils.cuh
Normal file
68
csrc/cutlass_extensions/cute_utils.cuh
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cute/tensor.hpp>
|
||||||
|
#include <torch/all.h>
|
||||||
|
namespace cute {
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
// layout utils
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Permute layout based on indices, example:
|
||||||
|
// permute_layout<1, 0>(layout) will swap the two dimensions
|
||||||
|
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
|
||||||
|
template <size_t... I, typename Layout>
|
||||||
|
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
|
||||||
|
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
|
||||||
|
return cute::make_layout(cute::get<I>(l)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// is the layout f(x) = x
|
||||||
|
template <typename Layout>
|
||||||
|
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
|
||||||
|
if constexpr (std::is_same_v<Layout, void>)
|
||||||
|
return true;
|
||||||
|
else {
|
||||||
|
constexpr auto coalesced_layout = coalesce(Layout{});
|
||||||
|
if constexpr (rank(coalesced_layout) == 1 &&
|
||||||
|
stride<0>(coalesced_layout) == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
// Pointer utils
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <class PointerType>
|
||||||
|
static constexpr auto get_logical_ptr(PointerType* ptr) {
|
||||||
|
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
|
||||||
|
return cute::subbyte_iterator<PointerType>(ptr);
|
||||||
|
} else {
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
// Misc utils
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T, typename Elements>
|
||||||
|
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
|
||||||
|
constexpr auto bits = sizeof_bits_v<T> * Elements{};
|
||||||
|
if constexpr (bits % 128 == 0) {
|
||||||
|
return AutoVectorizingCopyWithAssumedAlignment<128>{};
|
||||||
|
} else if constexpr (bits % 64 == 0) {
|
||||||
|
return AutoVectorizingCopyWithAssumedAlignment<64>{};
|
||||||
|
} else if constexpr (bits % 32 == 0) {
|
||||||
|
return AutoVectorizingCopyWithAssumedAlignment<32>{};
|
||||||
|
} else if constexpr (bits % 16 == 0) {
|
||||||
|
return AutoVectorizingCopyWithAssumedAlignment<16>{};
|
||||||
|
} else {
|
||||||
|
return AutoVectorizingCopyWithAssumedAlignment<8>{};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // namespace cute
|
||||||
154
csrc/cutlass_extensions/torch_utils.hpp
Normal file
154
csrc/cutlass_extensions/torch_utils.hpp
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "cute/layout.hpp"
|
||||||
|
#include "cutlass/layout/matrix.h"
|
||||||
|
#include "cutlass/bfloat16.h"
|
||||||
|
#include "cutlass/half.h"
|
||||||
|
|
||||||
|
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||||
|
using RowMajor = typename cutlass::layout::RowMajor;
|
||||||
|
|
||||||
|
namespace cute {
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
template <class T, class F, class G, int... I>
|
||||||
|
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
|
||||||
|
seq<I...>) {
|
||||||
|
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class F, int... I>
|
||||||
|
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
|
||||||
|
return make_shape(f(I)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // namespace detail
|
||||||
|
|
||||||
|
template <class T, class F>
|
||||||
|
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
|
||||||
|
if constexpr (cute::is_tuple<T>::value) {
|
||||||
|
return detail::tapply_with_idx(
|
||||||
|
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
|
||||||
|
tuple_seq<T>{});
|
||||||
|
} else {
|
||||||
|
return f(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTE_GCC_UNREACHABLE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// calls: make_shape(f(0), f(1), ..., f(N-1))
|
||||||
|
template <int N, class F>
|
||||||
|
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
|
||||||
|
return detail::make_shape_from_idx(f, make_seq<N>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // namespace cute
|
||||||
|
|
||||||
|
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
|
||||||
|
// shape of the passed in tensor and the strides are of type `Stride` and
|
||||||
|
// contain the strides of the passed in tensor, checking that any static strides
|
||||||
|
// in `Stride{}` match the strides of the passed in tensor.
|
||||||
|
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
|
||||||
|
// strides are set to be 0 or 1.
|
||||||
|
template <typename Stride>
|
||||||
|
static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||||
|
std::string_view name = "tensor") {
|
||||||
|
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
|
||||||
|
auto stride = cute::transform_with_idx(
|
||||||
|
Stride{}, [&](auto const& stride_ele, auto const& idx) {
|
||||||
|
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||||
|
|
||||||
|
if (idx < tensor.dim()) {
|
||||||
|
if constexpr (cute::is_static_v<StrideEle>) {
|
||||||
|
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
|
||||||
|
name, ".stride(", idx, ") to be ", StrideEle::value);
|
||||||
|
return StrideEle{};
|
||||||
|
} else {
|
||||||
|
return tensor.stride(idx);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Extra strides are assumed to be 0 or 1
|
||||||
|
if constexpr (cute::is_static_v<StrideEle>) {
|
||||||
|
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||||
|
}
|
||||||
|
return StrideEle{};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
|
||||||
|
if (idx < tensor.dim())
|
||||||
|
return tensor.size(idx);
|
||||||
|
else
|
||||||
|
return int64_t(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
return make_layout(shape, stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Stride>
|
||||||
|
static inline auto maybe_make_cute_layout(
|
||||||
|
c10::optional<torch::Tensor> const& tensor,
|
||||||
|
std::string_view name = "tensor") {
|
||||||
|
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||||
|
|
||||||
|
if (tensor) {
|
||||||
|
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
|
||||||
|
} else {
|
||||||
|
return std::optional<Layout>{};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Torch Type to Cutlass Type (equivalent_cutlass_type)
|
||||||
|
//
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct equivalent_cutlass_type {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct equivalent_cutlass_type<c10::Half> {
|
||||||
|
using type = cutlass::half_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct equivalent_cutlass_type<c10::BFloat16> {
|
||||||
|
using type = cutlass::bfloat16_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
|
||||||
|
//
|
||||||
|
|
||||||
|
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ type from
|
||||||
|
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
|
||||||
|
template <typename T>
|
||||||
|
struct equivalent_scalar_type {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct equivalent_scalar_type<cutlass::half_t> {
|
||||||
|
using type = c10::Half;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct equivalent_scalar_type<cutlass::bfloat16_t> {
|
||||||
|
using type = c10::BFloat16;
|
||||||
|
};
|
||||||
|
|
||||||
|
// get equivalent c10::ScalarType tag from compile time type
|
||||||
|
template <typename T>
|
||||||
|
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
|
||||||
|
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
|
||||||
43
csrc/cutlass_extensions/vllm_collective_builder.cuh
Normal file
43
csrc/cutlass_extensions/vllm_collective_builder.cuh
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::gemm::collective {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
//
|
||||||
|
// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for
|
||||||
|
// custom kernel tags, allowing you to build custom collectives. Without
|
||||||
|
// touching the cutlass library headers, using `CutlassKernelTag` will mean it
|
||||||
|
// will resort to using the standard cutlass collective builder.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Use the default Cutlass collective builder, i.e. use an unmodified cutlass
|
||||||
|
// collective
|
||||||
|
struct CutlassKernelTag {};
|
||||||
|
|
||||||
|
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
|
||||||
|
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
|
||||||
|
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
|
||||||
|
class ClusterShape_MNK, class StageCountType,
|
||||||
|
class KernelScheduleType, class Enable = void>
|
||||||
|
struct VLLMCollectiveBuilder {
|
||||||
|
static_assert(sizeof(ElementA) == 0,
|
||||||
|
"Could not build a collective for given parameters.");
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
|
||||||
|
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
|
||||||
|
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
|
||||||
|
class StageCountType, class KernelScheduleType>
|
||||||
|
struct VLLMCollectiveBuilder<
|
||||||
|
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
|
||||||
|
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||||
|
ClusterShape_MNK, StageCountType, KernelScheduleType> {
|
||||||
|
using CollectiveOp = typename CollectiveBuilder<
|
||||||
|
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
|
||||||
|
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||||
|
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
}; // namespace cutlass::gemm::collective
|
||||||
50
csrc/cutlass_extensions/vllm_custom_types.cuh
Normal file
50
csrc/cutlass_extensions/vllm_custom_types.cuh
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/integer_subbyte.h"
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <int Bits, int Bias, bool Signed = false>
|
||||||
|
struct vllm_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
|
||||||
|
using Base = integer_subbyte<Bits, Signed>;
|
||||||
|
|
||||||
|
using Storage = typename Base::Storage;
|
||||||
|
using xint_t = typename Base::xint_t;
|
||||||
|
|
||||||
|
using Base::bits_mask_;
|
||||||
|
using Base::sign_mask_;
|
||||||
|
using Base::storage;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Methods
|
||||||
|
//
|
||||||
|
|
||||||
|
/// No operation
|
||||||
|
vllm_biased_integer_subbyte() = default;
|
||||||
|
|
||||||
|
/// Conversion from integer type
|
||||||
|
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value)
|
||||||
|
: Base(value) {}
|
||||||
|
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value)
|
||||||
|
: Base(value) {}
|
||||||
|
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value)
|
||||||
|
: Base(value) {}
|
||||||
|
};
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// "GPTQ" types, i.e. symmetric quantization
|
||||||
|
using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8
|
||||||
|
using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <int Bits, int Bias, bool Signed>
|
||||||
|
struct sizeof_bits<vllm_biased_integer_subbyte<Bits, Bias, Signed>> {
|
||||||
|
static constexpr int value = Bits;
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace cutlass
|
||||||
49
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
Normal file
49
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import enum
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
from cutlass_library import *
|
||||||
|
|
||||||
|
#
|
||||||
|
# Extend cutlass library with custom types, and missing values
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
class VLLMDataType(enum.Enum):
|
||||||
|
u4b8 = enum_auto()
|
||||||
|
u8b128 = enum_auto()
|
||||||
|
|
||||||
|
|
||||||
|
class MixedInputKernelScheduleType(enum.Enum):
|
||||||
|
TmaWarpSpecializedMixedInput = enum_auto()
|
||||||
|
TmaWarpSpecializedPingpongMixedInput = enum_auto()
|
||||||
|
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
|
||||||
|
|
||||||
|
|
||||||
|
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
||||||
|
**DataTypeNames, # type: ignore
|
||||||
|
**{
|
||||||
|
VLLMDataType.u4b8: "u4b8",
|
||||||
|
VLLMDataType.u8b128: "u8b128",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||||
|
**DataTypeTag, # type: ignore
|
||||||
|
**{
|
||||||
|
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
||||||
|
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VLLMKernelScheduleTag: Dict[Union[
|
||||||
|
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||||
|
**KernelScheduleTag, # type: ignore
|
||||||
|
**{
|
||||||
|
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
|
||||||
|
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
|
||||||
|
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
|
||||||
|
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
|
||||||
|
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
|
||||||
|
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
|
||||||
|
}
|
||||||
|
}
|
||||||
795
csrc/cutlass_extensions/vllm_numeric_conversion.cuh
Normal file
795
csrc/cutlass_extensions/vllm_numeric_conversion.cuh
Normal file
@@ -0,0 +1,795 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/numeric_conversion.h"
|
||||||
|
#include "cutlass_extensions/vllm_custom_types.cuh"
|
||||||
|
#include "cutlass_extensions/cute_utils.cuh"
|
||||||
|
|
||||||
|
// this file extends:
|
||||||
|
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
|
||||||
|
// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t
|
||||||
|
// as well as adds interleaved numeric array converters for specific types.
|
||||||
|
// (interleaved numeric array converters can be more efficient for subbyte
|
||||||
|
// types)
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
|
||||||
|
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
|
||||||
|
// deinterleaves converted elements based on IlvBlkLayout. Interleaving can
|
||||||
|
// make subbyte converts more efficient by allowing for efficient extraction
|
||||||
|
// of subbyte elements from a 32bit register.
|
||||||
|
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||||
|
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
|
||||||
|
class Enable = void>
|
||||||
|
struct InterleavedNumericArrayConverter {
|
||||||
|
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||||
|
|
||||||
|
using result_type = typename Converter::result_type;
|
||||||
|
using source_type = typename Converter::source_type;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
static result_type convert(source_type const& source) {
|
||||||
|
CUTE_INVALID_CONTROL_PATH(
|
||||||
|
"InterleavedNumericArrayConverter not implemented\n");
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
result_type operator()(source_type const& s) const { return convert(s); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||||
|
FloatRoundStyle Round>
|
||||||
|
struct InterleavedNumericArrayConverter<
|
||||||
|
IlvBlkLayout, T, S, N, Round,
|
||||||
|
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
|
||||||
|
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||||
|
|
||||||
|
using result_type = typename Converter::result_type;
|
||||||
|
using source_type = typename Converter::source_type;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
static result_type convert(source_type const& source) {
|
||||||
|
return Converter::convert(source);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
result_type operator()(source_type const& s) const { return convert(s); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO (LucasWilkinson): Implement
|
||||||
|
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
|
||||||
|
|
||||||
|
// ....
|
||||||
|
|
||||||
|
template <typename RegConvert32bit, typename T, typename S, int N>
|
||||||
|
struct ArrayConverterPacked32Bit {
|
||||||
|
using result_type = Array<T, N>;
|
||||||
|
using source_type = Array<S, N>;
|
||||||
|
|
||||||
|
using result_packed_8_t = Array<T, 8>;
|
||||||
|
using result_packed_4_t = Array<T, 4>;
|
||||||
|
using result_packed_2_t = Array<T, 2>;
|
||||||
|
using src_packed_8_t = Array<S, 8>;
|
||||||
|
using src_packed_4_t = Array<S, 4>;
|
||||||
|
using src_packed_2_t = Array<S, 2>;
|
||||||
|
|
||||||
|
static_assert(N % 2 == 0, "N must be a multiple of 2");
|
||||||
|
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
|
||||||
|
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
|
||||||
|
static constexpr auto src_elems_per_32bit_reg =
|
||||||
|
32 / cutlass::sizeof_bits_v<S>;
|
||||||
|
|
||||||
|
// Maybe not Valid. ScalarConverter will not actually work unless
|
||||||
|
// NumericConverter<T, S, Round> is implemented. However it won't be used
|
||||||
|
// anyways since we assert N % 2 == 0, just here for compliance with
|
||||||
|
// VectorizedConverter.
|
||||||
|
using ScalarConverter = NumericConverter<T, S>;
|
||||||
|
|
||||||
|
template <typename PackedSrc>
|
||||||
|
CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
|
||||||
|
if constexpr (sizeof(PackedSrc) == 1) {
|
||||||
|
return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source));
|
||||||
|
} else if constexpr (sizeof(PackedSrc) == 2) {
|
||||||
|
return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source));
|
||||||
|
} else {
|
||||||
|
static_assert(sizeof(PackedSrc) == 4);
|
||||||
|
return reinterpret_cast<const uint32_t&>(source);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The core converter uses bit tricks to construct a known FP16 number, then
|
||||||
|
// does a subtraction in FP16 for the final result.
|
||||||
|
template <typename PackedResultType, typename PackedSrcType>
|
||||||
|
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||||
|
PackedSrcType const& source) {
|
||||||
|
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
|
||||||
|
static_assert(PackedResultType::kElements == 2 ||
|
||||||
|
PackedResultType::kElements == 4 ||
|
||||||
|
PackedResultType::kElements == 8,
|
||||||
|
"Invalid PackedResultType must be 2, 4 or 8.");
|
||||||
|
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
|
||||||
|
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
|
||||||
|
|
||||||
|
return RegConvert32bit::template convert<PackedResultType>(to_reg(source));
|
||||||
|
}
|
||||||
|
|
||||||
|
friend class detail::VectorizedConverter;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CUTLASS_DEVICE static result_type convert(source_type const& source) {
|
||||||
|
result_type result;
|
||||||
|
using ConverterType =
|
||||||
|
ArrayConverterPacked32Bit<RegConvert32bit,
|
||||||
|
typename result_type::Element,
|
||||||
|
typename source_type::Element, N>;
|
||||||
|
|
||||||
|
if constexpr (src_elems_per_32bit_reg >= 8) {
|
||||||
|
detail::VectorizedConverter::convert<
|
||||||
|
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
|
||||||
|
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
|
||||||
|
} else if constexpr (src_elems_per_32bit_reg >= 4) {
|
||||||
|
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||||
|
src_packed_4_t, result_packed_2_t,
|
||||||
|
src_packed_2_t>(result, source);
|
||||||
|
} else {
|
||||||
|
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
|
||||||
|
src_packed_2_t>(result, source);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
//
// Source elements are 4-bit values stored with a bias of 8
// (stored_val = x + 8). The conversion recovers x as an FP16 value, working on
// one 32-bit source register at a time through ArrayConverterPacked32Bit.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
  using result_type = Array<cutlass::half_t, N>;
  using source_type = Array<vllm_uint4b8_t, N>;

  struct RegConvert {
    // Converts the nibbles packed in `src` into a PackedResultType. Every
    // 32-bit output register holds the two FP16 values decoded from one
    // source byte.
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
      RegArray r;

      // Below constructs the following temporary:
      // fp16s_01 = {0x00, i4_01, 0x00, i4_01}
      // fp16s_23 = {0x00, i4_23, 0x00, i4_23}
      // fp16s_45 = {0x00, i4_45, 0x00, i4_45}
      // fp16s_67 = {0x00, i4_67, 0x00, i4_67}
      // We use inline asm instead of __byte_perm intrinsic since we don't want
      // the documented (& 0x7) on the index. NVCC might be able to optimize it
      // out since the index is a constexpr, but we choose to be safe about it
      // here.
      uint32_t const prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
      // One selector per output register; guards the prmt_indices lookup.
      static_assert(RegArray::kElements <= 4,
                    "Too many inputs for uint4b8_t -> F16 vector converter");
      CUTLASS_PRAGMA_UNROLL
      for (int ii = 0; ii < RegArray::kElements; ++ii) {
        asm volatile(
            "{\n"
            " prmt.b32 %0, %1, %2, %3;\n"
            "}\n"
            : "=r"(r[ii])
            : "r"(src), "n"(0), "r"(prmt_indices[ii]));
      }

      // Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
      // we are trying to construct x and a fp16 value
      // The below XOR does the following:
      //  1) Sets the exponent bits of the FP16 to the correct value for the
      //  FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
      //  where x1 in the high nibble and x0 is the low nibble then using hfma
      //  to subtract 1032 from that
      // The AND does the following:
      //  1) Clear the set bits for the int4 we will ignore.
      // We use lop3 so that we can use 1 instruction for AND and XOR.
      static constexpr uint32_t xor_mask = 0x64006400;
      static constexpr uint32_t and_mask = 0xFFF0FF0F;
      static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;

      // For each operand, computes:
      // r[i] = (r[i] & and_mask) ^ xor_mask
      CUTLASS_PRAGMA_UNROLL
      for (int ii = 0; ii < RegArray::kElements; ++ii) {
        asm volatile(
            "{\n"
            " lop3.b32 %0, %0, %1, %2, %3;\n"
            "}\n"
            : "+r"(r[ii])
            : "n"(and_mask), "n"(xor_mask), "n"(immLut));
      }

      // One hfma2 per register computes (written {high half, low half}):
      // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
      // The bias constant therefore encodes the negated values.
      static constexpr uint32_t hfma_bias_rep = 0xD480E408;   // {-72, -1032}
      static constexpr uint32_t hfma_scale_rep = 0x2C003C00;  // {1 / 16, 1}

      const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
      const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
      CUTLASS_PRAGMA_UNROLL
      for (int ii = 0; ii < RegArray::kElements; ++ii) {
        half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
        fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
      }

      return reinterpret_cast<PackedResultType&>(r);
    }
  };

 public:
  // Converts the whole array by dispatching packed chunks to RegConvert.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
//
// Source nibbles carry a bias of 8 (stored_val = x + 8). Two nibbles are
// decoded per FP16x2 register: registers at even index take nibbles 0 and 4
// of the (shifted) source word, registers at odd index take nibbles 1 and 5.
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
                                        cutlass::half_t, vllm_uint4b8_t, N,
                                        Round, void> {
  using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
  static_assert(N % size(IlvdLayout{}) == 0);

  using result_type = Array<cutlass::half_t, N>;
  using source_type = Array<vllm_uint4b8_t, N>;

  static FloatRoundStyle const round_style = Round;

 private:
  struct RegConvert {
    // Decodes the nibbles of `src` into PackedResultType, filling the output
    // registers two at a time.
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
      RegArray r;

      static_assert(PackedResultType::kElements <= size(IlvdLayout{}));

      // lop3 truth table for (a & b) ^ c.
      static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
      // FP16 bit pattern of 1024 in both halves; XOR-ing it over the masked
      // nibble yields 1024 + nibble (or 1024 + 16 * nibble for the high one).
      static constexpr uint32_t xor_mask = 0x64006400;
      static constexpr uint32_t low_nib_mask = 0x000F000F;
      static constexpr uint32_t high_nib_mask = 0x00F000F0;

      // For low nibble:
      //  {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
      // For high nibble:
      //  {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
      //             - {72, 72}
      static constexpr uint32_t low_nib_bias = 0x64086408;    // {1032, 1032}
      static constexpr uint32_t high_nib_scale = 0x2C002C00;  // {1/16, 1/16}
      static constexpr uint32_t high_nib_bias = 0xD480D480;   // {-72, -72}

      for (int reg = 0; reg < RegArray::kElements; reg += 2) {
        // Each pair of output registers consumes the next source byte pair.
        uint32_t const shifted = src >> (4 * reg);
        r[reg + 0] = shifted;
        r[reg + 1] = shifted;

        // r = (r & low_nib_mask) ^ xor_mask
        asm volatile(
            "{\n"
            " lop3.b32 %0, %0, %1, %2, %3;\n"
            "}\n"
            : "+r"(r[reg + 0])
            : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));

        // r = (r & high_nib_mask) ^ xor_mask
        asm volatile(
            "{\n"
            " lop3.b32 %0, %0, %1, %2, %3;\n"
            "}\n"
            : "+r"(r[reg + 1])
            : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));

        half2& low_pair = reinterpret_cast<__half2&>(r[reg + 0]);
        low_pair =
            __hsub2(low_pair, reinterpret_cast<const half2&>(low_nib_bias));

        half2& high_pair = reinterpret_cast<__half2&>(r[reg + 1]);
        high_pair = __hfma2(high_pair,
                            reinterpret_cast<const half2&>(high_nib_scale),
                            reinterpret_cast<const half2&>(high_nib_bias));
      }

      return reinterpret_cast<PackedResultType&>(r);
    }
  };

 public:
  // Converts the whole array by dispatching packed chunks to RegConvert.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
//
// Unbiased variant: two nibbles are decoded per FP16x2 register. Registers at
// even index take nibbles 0 and 4 of the (shifted) source word, registers at
// odd index take nibbles 1 and 5.
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
                                        cutlass::half_t, uint4_t, N, Round,
                                        void> {
  using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
  static_assert(N % size(IlvdLayout{}) == 0);

  using result_type = Array<cutlass::half_t, N>;
  using source_type = Array<uint4_t, N>;

  static FloatRoundStyle const round_style = Round;

 private:
  struct RegConvert {
    // Decodes the nibbles of `src` into PackedResultType, filling the output
    // registers two at a time.
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
      RegArray r;

      static_assert(PackedResultType::kElements <= size(IlvdLayout{}));

      // lop3 truth table for (a & b) ^ c.
      static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
      // FP16 bit pattern of 1024 in both halves.
      static constexpr uint32_t xor_mask = 0x64006400;
      static constexpr uint32_t low_nib_mask = 0x000F000F;
      static constexpr uint32_t high_nib_mask = 0x00F000F0;

      // For low nibble:
      //  {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
      // For high nibble:
      //  {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
      static constexpr uint32_t low_nib_bias = 0x64006400;    // {1024, 1024}
      static constexpr uint32_t high_nib_scale = 0x2C002C00;  // {1/16, 1/16}
      static constexpr uint32_t high_nib_bias = 0xD400D400;   // {-64, -64}

      for (int reg = 0; reg < RegArray::kElements; reg += 2) {
        // Each pair of output registers consumes the next source byte pair.
        uint32_t const shifted = src >> (4 * reg);
        r[reg + 0] = shifted;
        r[reg + 1] = shifted;

        // r = (r & low_nib_mask) ^ xor_mask
        asm volatile(
            "{\n"
            " lop3.b32 %0, %0, %1, %2, %3;\n"
            "}\n"
            : "+r"(r[reg + 0])
            : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));

        // r = (r & high_nib_mask) ^ xor_mask
        asm volatile(
            "{\n"
            " lop3.b32 %0, %0, %1, %2, %3;\n"
            "}\n"
            : "+r"(r[reg + 1])
            : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));

        half2& low_pair = reinterpret_cast<__half2&>(r[reg + 0]);
        low_pair =
            __hsub2(low_pair, reinterpret_cast<const half2&>(low_nib_bias));

        half2& high_pair = reinterpret_cast<__half2&>(r[reg + 1]);
        high_pair = __hfma2(high_pair,
                            reinterpret_cast<const half2&>(high_nib_scale),
                            reinterpret_cast<const half2&>(high_nib_bias));
      }

      return reinterpret_cast<PackedResultType&>(r);
    }
  };

 public:
  // Converts the whole array by dispatching packed chunks to RegConvert.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
// for Array<cutlass::half_t, N> <= Array<vllm_uint8b128_t, N>
//
// Source elements are 8-bit values stored with a bias of 128. Each source
// byte is paired with the byte 0x64 so that the resulting FP16 reads as
// 1024 + stored_val; one packed subtraction then removes both the 1024 and
// the 128 bias.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {
  using result_type = Array<cutlass::half_t, N>;
  using source_type = Array<vllm_uint8b128_t, N>;

  struct RegConvert {
    // Converts the bytes packed in `src` into a PackedResultType.
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
      // Hold output FP16s in reg. We need 1 reg for every 2 elements
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
      RegArray r;

      // One selector per output register: {0x64, src byte 2i+1, 0x64, src
      // byte 2i} written from the most to the least significant byte.
      uint32_t const prmt_indices[2] = {0x5150, 0x5352};
      // Guards the prmt_indices lookup below.
      static_assert(RegArray::kElements <= 2,
                    "Too many inputs for uint8b128_t -> F16 vector converter");
      static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

      CUTLASS_PRAGMA_UNROLL
      for (int ii = 0; ii < RegArray::kElements; ++ii) {
        asm volatile("prmt.b32 %0,%1,%2,%3;\n"
                     : "=r"(r[ii])
                     : "r"(src), "n"(start_byte_for_fp16),
                       "r"(prmt_indices[ii]));
      }

      // -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
      static constexpr uint32_t bias_rep = 0x64806480;
      const half2& bias = reinterpret_cast<const half2&>(bias_rep);
      CUTLASS_PRAGMA_UNROLL
      for (int ii = 0; ii < RegArray::kElements; ++ii) {
        half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
        fp16x2_val = __hsub2(fp16x2_val, bias);
      }

      return reinterpret_cast<PackedResultType&>(r);
    }
  };

 public:
  // Converts the whole array by dispatching packed chunks to RegConvert.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
// for Array<float, N> <= Array<vllm_uint8b128_t, N>
//
// Source elements are 8-bit values stored with a bias of 128. Each byte is
// placed in the low byte of the float bit pattern 0x4B000000 (8388608.0f),
// which reads as 8388608 + stored_val; a float subtraction then removes both
// the magic number and the 128 bias.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
  using result_type = Array<float, N>;
  using source_type = Array<vllm_uint8b128_t, N>;
  static FloatRoundStyle const round_style = Round;

 private:
  struct RegConvert {
    // Converts the bytes packed in `src` into a PackedResultType of floats.
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
      PackedResultType r;

      // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
      // u8x4 source and stores the result in r (without introducing extra
      // cvt.u32.u8 instruction)
      uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
      // One selector per output element; guards the prmt_indices lookup.
      static_assert(PackedResultType::kElements <= 4,
                    "Too many inputs for uint8b128_t -> F32 vector converter");
      uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
      CUTLASS_PRAGMA_UNROLL
      for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
        result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
        // Subtract the magic number 0x4B000000 (8388608.f) from the element
        // in floating-point arithmetic to obtain the final result
        r[ii] -= (8388608.f + 128.f);  // fold in -128 bias
      }

      return r;
    }
  };

 public:
  // Converts the whole array by dispatching packed chunks to RegConvert.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||||
|
|
||||||
|
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
//
// Source elements are 4-bit values stored with a bias of 8
// (stored_val = x + 8). Each nibble is moved into the low bits of its own
// BF16 half, turned into 128 + stored_val, and the 136 is subtracted.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
  using result_type = Array<cutlass::bfloat16_t, N>;
  using source_type = Array<vllm_uint4b8_t, N>;

  static FloatRoundStyle const round_style = Round;

 private:
  struct RegConvert {
    // Converts the nibbles packed in `src_reg` into a PackedResultType.
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
      // Hold output BF16s in reg. We need 1 reg for every 2 elements
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
      RegArray r;

      // Shifting by one nibble lets prmt fetch the high nibble of each byte
      // already aligned to bit 0.
      uint32_t const src_shifted = src_reg >> 4;

      // Per output register: the low half receives source byte i, the high
      // half receives byte i of the shifted word. The remaining bytes are
      // cleared by the AND mask applied afterwards.
      uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
      static_assert(RegArray::kElements <= 4,
                    "Too many inputs for uint4b8_t -> BF16 vector converter");
      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < RegArray::kElements; ++idx) {
        asm volatile(
            "{\n"
            " prmt.b32 %0, %1, %2, %3;\n"
            "}\n"
            : "=r"(r[idx])
            : "r"(src_reg), "r"(src_shifted), "r"(prmt_indices[idx]));
      }

      // Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
      // and we are trying to construct x as a BF16 value.
      // The XOR sets the exponent bits of the BF16 magic number, giving
      // {128 + (x1+8), 128 + (x0+8)}; subtracting 136 yields {x1, x0}.
      // The AND keeps only the low nibble of each 16-bit half.
      static constexpr uint32_t xor_mask = 0x43004300;
      static constexpr uint32_t and_mask = 0x000F000F;
      static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;

      // For each operand, computes:
      // r[i] = (r[i] & and_mask) ^ xor_mask
      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < RegArray::kElements; ++idx) {
        asm volatile(
            "{\n"
            " lop3.b32 %0, %0, %1, %2, %3;\n"
            "}\n"
            : "+r"(r[idx])
            : "n"(and_mask), "n"(xor_mask), "n"(immLut));
      }

      // One packed subtraction per register: {hi_bf16 - 136, lo_bf16 - 136}.
      // This is the BF16 {136, 136} represented as an integer.
      static constexpr uint32_t bias_rep = 0x43084308;
      const __nv_bfloat162& bias =
          reinterpret_cast<const __nv_bfloat162&>(bias_rep);

      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < RegArray::kElements; ++idx) {
        __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[idx]);
        bf16x2_val = __hsub2(bf16x2_val, bias);
      }

      return reinterpret_cast<PackedResultType&>(r);
    }
  };

 public:
  // Converts the whole array by dispatching packed chunks to RegConvert.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
//
// Source nibbles carry a bias of 8 (stored_val = x + 8). Output register i
// decodes nibbles i and i + 4 of the source word.
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
                                        cutlass::bfloat16_t, vllm_uint4b8_t, N,
                                        Round, void> {
  using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
  static_assert(N % size(IlvdLayout{}) == 0);

  using result_type = Array<cutlass::bfloat16_t, N>;
  using source_type = Array<vllm_uint4b8_t, N>;

 private:
  struct RegConvert {
    // Decodes the nibbles of `src` into PackedResultType, one output
    // register (two BF16 values) per iteration.
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
      RegArray r;

      static_assert(PackedResultType::kElements <= size(IlvdLayout{}));

      // BF16 bit pattern of 128 in both halves.
      static constexpr uint32_t or_mask = 0x43004300;
      // lop3 truth table for (a & b) | c.
      static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
      static constexpr uint32_t low_nib_mask = 0x000F000F;

      // For low nibble:
      //  {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
      static constexpr uint32_t low_nib_bias = 0x43084308;  // {136, 136}

      // Unlike float16 where the mantissa is large enough to contain 2
      // nibbles, bfloat16 can only fit one, so we can only convert one
      // nibble at a time
      for (int reg = 0; reg < RegArray::kElements; ++reg) {
        r[reg] = src >> (4 * reg);

        // r = (r & low_nib_mask) | or_mask
        asm volatile(
            "{\n"
            " lop3.b32 %0, %0, %1, %2, %3;\n"
            "}\n"
            : "+r"(r[reg])
            : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));

        __nv_bfloat162& bf16_pair = reinterpret_cast<__nv_bfloat162&>(r[reg]);
        bf16_pair = __hsub2(
            bf16_pair, reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
      }

      return reinterpret_cast<PackedResultType&>(r);
    }
  };

 public:
  // Converts the whole array by dispatching packed chunks to RegConvert.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
//
// Unbiased variant: output register i decodes nibbles i and i + 4 of the
// source word.
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
                                        cutlass::bfloat16_t, uint4_t, N, Round,
                                        void> {
  using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
  static_assert(N % size(IlvdLayout{}) == 0);

  using result_type = Array<cutlass::bfloat16_t, N>;
  using source_type = Array<uint4_t, N>;

 private:
  struct RegConvert {
    // Decodes the nibbles of `src` into PackedResultType, one output
    // register (two BF16 values) per iteration.
    template <typename PackedResultType>
    CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
      using RegArray =
          cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                sizeof(PackedResultType)>;
      RegArray r;

      static_assert(PackedResultType::kElements <= size(IlvdLayout{}));

      // BF16 bit pattern of 128 in both halves.
      static constexpr uint32_t or_mask = 0x43004300;
      // lop3 truth table for (a & b) | c.
      static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
      static constexpr uint32_t low_nib_mask = 0x000F000F;

      // For low nibble:
      //  {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
      static constexpr uint32_t low_nib_bias = 0x43004300;  // {128, 128}

      // Unlike float16 where the mantissa is large enough to contain 2
      // nibbles, bfloat16 can only fit one, so we can only convert one
      // nibble at a time
      for (int reg = 0; reg < RegArray::kElements; ++reg) {
        r[reg] = src >> (4 * reg);

        // r = (r & low_nib_mask) | or_mask
        asm volatile(
            "{\n"
            " lop3.b32 %0, %0, %1, %2, %3;\n"
            "}\n"
            : "+r"(r[reg])
            : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));

        __nv_bfloat162& bf16_pair = reinterpret_cast<__nv_bfloat162&>(r[reg]);
        bf16_pair = __hsub2(
            bf16_pair, reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
      }

      return reinterpret_cast<PackedResultType&>(r);
    }
  };

 public:
  // Converts the whole array by dispatching packed chunks to RegConvert.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                     typename source_type::Element,
                                     N>::convert(source);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint8b128_t, N>
//
// Implemented as a two-step conversion: vllm_uint8b128_t -> float -> bfloat16,
// applied to packed groups of 4 or 2 elements.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {
  using result_type = Array<cutlass::bfloat16_t, N>;
  using source_type = Array<vllm_uint8b128_t, N>;
  static FloatRoundStyle const round_style = Round;

 private:
  using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
  using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
  using src_packed_4_t = Array<vllm_uint8b128_t, 4>;
  using src_packed_2_t = Array<vllm_uint8b128_t, 2>;

  // Not Valid, not supported, only here to satisfy the interface and to avoid
  //  a compile error. ScalarConverter will not actually work until
  //  NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round> is
  //  implemented
  using ScalarConverter =
      NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round>;

  // Converts one packed group (2 or 4 elements) by going through float.
  template <typename PackedResultType, typename PackedSrcType>
  CUTLASS_DEVICE static PackedResultType packed_convert(
      PackedSrcType const& source) {
    constexpr bool is_packed_2 =
        platform::is_same<PackedSrcType, src_packed_2_t>::value &&
        platform::is_same<PackedResultType, result_packed_2_t>::value;
    constexpr bool is_packed_4 =
        platform::is_same<PackedSrcType, src_packed_4_t>::value &&
        platform::is_same<PackedResultType, result_packed_4_t>::value;
    static_assert(
        is_packed_2 || is_packed_4,
        "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
        "convert dispatch.");

    constexpr int kCount = PackedResultType::kElements;
    using ToF32 = NumericArrayConverter<float, vllm_uint8b128_t, kCount, Round>;
    using ToBF16 =
        NumericArrayConverter<cutlass::bfloat16_t, float, kCount, Round>;

    ToF32 to_f32;
    Array<float, kCount> as_f32 = to_f32(source);
    ToBF16 to_bf16;
    return to_bf16(as_f32);
  }

  // VectorizedConverter needs access to the private packed_convert.
  friend class detail::VectorizedConverter;

 public:
  // Converts the whole array in packed groups of 4, then 2, elements.
  CUTLASS_DEVICE
  static result_type convert(source_type const& source) {
    using ConverterType =
        NumericArrayConverter<typename result_type::Element,
                              typename source_type::Element, N, Round>;

    result_type result;
    detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
                                         src_packed_4_t, result_packed_2_t,
                                         src_packed_2_t>(result, source);
    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -3,13 +3,16 @@
|
|||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "reduction_utils.cuh"
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
|
#include <cub/util_type.cuh>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
#else
|
#else
|
||||||
#include <hip/hip_bf16.h>
|
#include <hip/hip_bf16.h>
|
||||||
#include <hip/hip_fp16.h>
|
#include <hip/hip_fp16.h>
|
||||||
|
#include <hipcub/util_type.hpp>
|
||||||
|
#include <hipcub/hipcub.hpp>
|
||||||
|
|
||||||
using __nv_bfloat16 = __hip_bfloat16;
|
using __nv_bfloat16 = __hip_bfloat16;
|
||||||
using __nv_bfloat162 = __hip_bfloat162;
|
using __nv_bfloat162 = __hip_bfloat162;
|
||||||
@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
|
|||||||
const float x = (float)input[blockIdx.x * hidden_size + idx];
|
const float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||||
variance += x * x;
|
variance += x * x;
|
||||||
}
|
}
|
||||||
variance = blockReduceSum<float>(variance);
|
|
||||||
|
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||||
|
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
}
|
}
|
||||||
@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
|
|||||||
variance += temp.sum_squares();
|
variance += temp.sum_squares();
|
||||||
residual_v[id] = temp;
|
residual_v[id] = temp;
|
||||||
}
|
}
|
||||||
/* Keep the following if-else block in sync with the
|
|
||||||
calculation of max_block_size in fused_add_rms_norm */
|
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||||
if (num_tokens < 256) {
|
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||||
variance = blockReduceSum<float, 1024>(variance);
|
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||||
} else
|
|
||||||
variance = blockReduceSum<float, 256>(variance);
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
}
|
}
|
||||||
@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
|
|||||||
variance += x * x;
|
variance += x * x;
|
||||||
residual[blockIdx.x * hidden_size + idx] = z;
|
residual[blockIdx.x * hidden_size + idx] = z;
|
||||||
}
|
}
|
||||||
/* Keep the following if-else block in sync with the
|
|
||||||
calculation of max_block_size in fused_add_rms_norm */
|
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||||
if (num_tokens < 256) {
|
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||||
variance = blockReduceSum<float, 1024>(variance);
|
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||||
} else
|
|
||||||
variance = blockReduceSum<float, 256>(variance);
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
}
|
}
|
||||||
|
|||||||
700
csrc/mamba/causal_conv1d/causal_conv1d.cu
Normal file
700
csrc/mamba/causal_conv1d/causal_conv1d.cu
Normal file
@@ -0,0 +1,700 @@
|
|||||||
|
// clang-format off
|
||||||
|
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||||
|
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include "causal_conv1d.h"
|
||||||
|
#include <c10/util/BFloat16.h>
|
||||||
|
#include <c10/util/Half.h>
|
||||||
|
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||||
|
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
#include <cub/block/block_store.cuh>
|
||||||
|
|
||||||
|
#include "static_switch.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Fails via TORCH_CHECK unless tensor `x` has exactly the sizes listed in the
// variadic arguments; the message names the tensor and the expected shape.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||||
|
|
||||||
|
// Runtime-to-compile-time dispatch on ITYPE (an at::ScalarType). For Half,
// BFloat16 and Float it aliases both `input_t` and `weight_t` to the matching
// C++ type and invokes the callable passed as the variadic argument; any other
// type raises AT_ERROR mentioning NAME.
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)              \
    if (ITYPE == at::ScalarType::Half) {                                            \
        using input_t = at::Half;                                                   \
        using weight_t = at::Half;                                                  \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
        using input_t = at::BFloat16;                                               \
        using weight_t = at::BFloat16;                                              \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::Float)  {                                   \
        using input_t = float;                                                      \
        using weight_t = float;                                                     \
        __VA_ARGS__();                                                              \
    } else {                                                                        \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }
|
||||||
|
|
||||||
|
|
||||||
|
// Forward declarations of the typed launchers; each takes the filled
// ConvParamsBase by reference and the CUDA stream to run on.
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
template <typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
|
||||||
|
|
||||||
|
// Populates `params` with the sizes, data pointers and strides used by the kernels in this file.
//
// The whole struct is zeroed first, so any field that is not assigned here starts out as
// zero / nullptr; callers set the optional fields (seq_idx, initial/final states, conv_state)
// afterwards.
//
//   batch, dim, seqlen, width : problem sizes copied verbatim into `params`.
//   x, weight, out            : tensors whose data pointers and strides are recorded.
//                               x/out use strides of dims 0, 1 and the last dim;
//                               weight uses strides of dims 0 and 1.
//   bias_ptr                  : raw bias pointer, may be nullptr.
//   silu_activation           : stored as-is; the kernels apply x / (1 + exp(-x)) when set.
void set_conv_params_fwd(ConvParamsBase &params,
                         // sizes
                         const size_t batch,
                         const size_t dim,
                         const size_t seqlen,
                         const size_t width,
                         // device pointers
                         const at::Tensor x,
                         const at::Tensor weight,
                         const at::Tensor out,
                         void* bias_ptr,
                         bool silu_activation) {

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.width = width;

    params.silu_activation = silu_activation;

    // Set the pointers and strides.
    params.x_ptr = x.data_ptr();
    params.weight_ptr = weight.data_ptr();
    params.bias_ptr = bias_ptr;
    params.out_ptr = out.data_ptr();
    // All stride are in elements, not bytes.
    params.x_batch_stride = x.stride(0);
    params.x_c_stride = x.stride(1);
    params.x_l_stride = x.stride(-1);
    params.weight_c_stride = weight.stride(0);
    params.weight_width_stride = weight.stride(1);
    params.out_batch_stride = out.stride(0);
    params.out_c_stride = out.stride(1);
    params.out_l_stride = out.stride(-1);
}
|
||||||
|
|
||||||
|
|
||||||
|
// Host entry point of the causal conv1d forward pass.
//
// Validates the arguments, allocates the output with torch::empty_like(x), fills a
// ConvParamsBase and launches either the channel-last kernel (x.stride(1) == 1 and
// x.stride(2) > 1) or the default kernel on the current CUDA stream of x's device.
//
//   x                 : (batch_size, dim, seqlen), Float / Half / BFloat16, CUDA.
//   weight            : (dim, width) with 2 <= width <= 4, same dtype as x, CUDA.
//   bias_             : optional (dim,), same dtype as weight, unit stride.
//   seq_idx_          : optional (batch_size, seqlen) contiguous int32; channel-last only.
//   initial_states_   : optional (batch_size, dim, width - 1), stride(1) == 1; channel-last only.
//   final_states_out_ : optional (batch_size, dim, width - 1), stride(1) == 1; channel-last only.
//                       Written by the kernel.
//   silu_activation   : apply x / (1 + exp(-x)) to the result.
//
// Returns the output tensor. Any failed precondition raises through TORCH_CHECK.
at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                  const c10::optional<at::Tensor> &bias_,
                  const c10::optional<at::Tensor> &seq_idx_,
                  const c10::optional<at::Tensor> &initial_states_,
                  const c10::optional<at::Tensor> &final_states_out_,
                  bool silu_activation) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    // The dispatch below instantiates the kernels with weight_t == input_t, so a weight of a
    // different dtype would be reinterpreted bit-for-bit. Reject it, as causal_conv1d_update does.
    TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    // Validate the rank before indexing sizes[0..2].
    TORCH_CHECK(x.dim() == 3, "x must be a 3-dimensional tensor of shape (batch_size, dim, seqlen)");

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;

    if (is_channel_last) {
        TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
        TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
    }
    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    if (seq_idx_.has_value()) {
        TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
        auto seq_idx = seq_idx_.value();
        TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
        TORCH_CHECK(seq_idx.is_cuda());
        TORCH_CHECK(seq_idx.is_contiguous());
        CHECK_SHAPE(seq_idx, batch_size, seqlen);
    }

    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);

    if (seq_idx_.has_value()) {
        params.seq_idx_ptr = seq_idx_.value().data_ptr();
    } else {
        params.seq_idx_ptr = nullptr;
    }

    if (initial_states_.has_value()) {
        TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
        auto initial_states = initial_states_.value();
        TORCH_CHECK(initial_states.scalar_type() == input_type);
        TORCH_CHECK(initial_states.is_cuda());
        CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
        TORCH_CHECK(initial_states.stride(1) == 1);
        params.initial_states_ptr = initial_states.data_ptr();
        params.initial_states_batch_stride = initial_states.stride(0);
        params.initial_states_c_stride = initial_states.stride(1);
        params.initial_states_l_stride = initial_states.stride(2);
    } else {
        params.initial_states_ptr = nullptr;
    }

    if (final_states_out_.has_value()) {
        TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
        auto final_states = final_states_out_.value();
        TORCH_CHECK(final_states.scalar_type() == input_type);
        TORCH_CHECK(final_states.is_cuda());
        CHECK_SHAPE(final_states, batch_size, dim, width - 1);
        TORCH_CHECK(final_states.stride(1) == 1);
        params.final_states_ptr = final_states.data_ptr();
        params.final_states_batch_stride = final_states.stride(0);
        params.final_states_c_stride = final_states.stride(1);
        params.final_states_l_stride = final_states.stride(2);
    } else {
        params.final_states_ptr = nullptr;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
        if (!is_channel_last) {
            causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
        } else {
            causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
        }
    });
    return out;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Host entry point of the single-step causal conv1d update.
//
// Validates the arguments, allocates the output with torch::empty_like(x), fills a
// ConvParamsBase (with seqlen fixed to 1) and launches the update kernel on the current
// CUDA stream of x's device.
//
//   x               : (batch_size, dim), Float / Half / BFloat16, CUDA.
//   conv_state      : (batch_size, dim, width), same dtype as x, CUDA. The update kernel
//                     writes to it in place even though it is taken by const reference.
//   weight          : (dim, width) with 2 <= width <= 4, same dtype as x, CUDA.
//   bias_           : optional (dim,), same dtype as weight, unit stride.
//   silu_activation : apply x / (1 + exp(-x)) to the result.
//
// Returns the output tensor. Any failed precondition raises through TORCH_CHECK.
at::Tensor
causal_conv1d_update(const at::Tensor &x,
                     const at::Tensor &conv_state,
                     const at::Tensor &weight,
                     const c10::optional<at::Tensor> &bias_,
                     bool silu_activation) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
    TORCH_CHECK(conv_state.scalar_type() == input_type);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(conv_state.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    // Validate the rank before indexing sizes[0..1].
    TORCH_CHECK(x.dim() == 2, "x must be a 2-dimensional tensor of shape (batch_size, dim)");

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim);
    CHECK_SHAPE(conv_state, batch_size, dim, width);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);
    params.conv_state_ptr = conv_state.data_ptr();
    // All stride are in elements, not bytes.
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
        causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
    });
    return out;
}
|
||||||
|
|
||||||
|
// Compile-time configuration of causal_conv1d_fwd_kernel.
//
//   kNThreads_  : threads per block.
//   kWidth_     : convolution width (number of taps).
//   kIsVecLoad_ : selects the vec_t-based direct load/store path instead of the
//                 warp-transposed element-wise path.
//   input_t_    : element type of x / out (2 or 4 bytes).
//   weight_t_   : element type of weight / bias.
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;

    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr bool kIsVecLoad = kIsVecLoad_;

    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    // Each thread handles 16 bytes worth of elements: 4 for 4-byte types, 8 for 2-byte types.
    static constexpr int kNElts = 16 / kNBytes;
    // The kernel takes its look-back from a single neighbouring group of kNElts elements.
    static_assert(kWidth <= kNElts);

    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;

    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;

    // Dynamic shared memory layout: [ load/store temp storage | per-thread exchange slots ].
    // The vectorized path reserves no temp-storage region.
    static constexpr int kSmemIOSize = kIsVecLoad
        ? 0
        : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};
|
||||||
|
|
||||||
|
// Causal conv1d forward kernel.
//
// Launch layout: blockIdx.x selects the batch entry and blockIdx.y the channel; each block
// walks the sequence in chunks of kNThreads * kNElts elements, with every thread owning
// kNElts consecutive positions. Requires Ktraits::kSmemSize bytes of dynamic shared memory.
// Elements along the sequence are addressed with unit stride.
//
// Each thread needs the (kWidth - 1) elements preceding its own; they are obtained by
// passing each thread's kNElts-element vector to its right neighbour through smem_exchange.
// Slot kNThreads - 1 carries the tail of one chunk into the next (zeros for the first chunk).
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;

    // Shared memory: the load/store temp storages alias the same bytes at the start of the
    // buffer; the exchange slots follow them.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

    const int tid = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y;

    // Row of x / out for this (batch, channel) and the matching weight row.
    input_t *x_row = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;
    weight_t *w_row = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out_row = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    // Thread 0 reads the tail of the previous chunk from the last slot; seed it with zeros
    // so the first chunk sees zero history.
    if (tid == 0) {
        input_t zero_fill[kNElts] = {0};
        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zero_fill)[0];
    }

    float taps[kWidth];
    #pragma unroll
    for (int w = 0; w < kWidth; ++w) { taps[w] = float(w_row[w * params.weight_width_stride]); }

    constexpr int kChunkSize = kNThreads * kNElts;
    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        // Upper half receives this thread's elements, lower half the left neighbour's.
        input_t x_raw[2 * kNElts] = {0};
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x_row), *reinterpret_cast<vec_t (*)[1]>(&x_raw[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            // The load temp storage aliases the store temp storage of the previous iteration.
            __syncthreads();
            typename Ktraits::BlockLoadT(smem_load).Load(x_row, *reinterpret_cast<input_t (*)[kNElts]>(&x_raw[kNElts]), params.seqlen - chunk * kChunkSize);
        }
        x_row += kChunkSize;
        __syncthreads();
        // The last thread holds back its write so that thread 0 can still read the tail of
        // the previous chunk from slot kNThreads - 1.
        if (tid < kNThreads - 1) { smem_exchange[tid] = reinterpret_cast<vec_t *>(x_raw)[1]; }
        __syncthreads();
        reinterpret_cast<vec_t *>(x_raw)[0] = smem_exchange[tid > 0 ? tid - 1 : kNThreads - 1];
        __syncthreads();
        // Everyone has read; the last thread may now publish the tail of the current chunk.
        if (tid == kNThreads - 1) { smem_exchange[tid] = reinterpret_cast<vec_t *>(x_raw)[1]; }

        // Accumulate in float regardless of the storage type.
        float x_f32[2 * kNElts];
        #pragma unroll
        for (int i = 0; i < 2 * kNElts; ++i) { x_f32[i] = float(x_raw[i]); }

        float acc[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            acc[i] = bias;
            #pragma unroll
            for (int w = 0; w < kWidth; ++w) {
                // Tap w applies to the element (kWidth - 1 - w) positions before position i.
                acc[i] += taps[w] * x_f32[kNElts + i - (kWidth - w - 1)];
            }
        }

        if (params.silu_activation) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                acc[i] = acc[i] / (1 + expf(-acc[i]));
            }
        }

        input_t out_raw[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { out_raw[i] = acc[i]; }
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out_row), reinterpret_cast<vec_t (&)[1]>(out_raw), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            typename Ktraits::BlockStoreT(smem_store).Store(out_row, out_raw, params.seqlen - chunk * kChunkSize);
        }
        out_row += kChunkSize;
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
// Launches causal_conv1d_fwd_kernel on a (batch, dim) grid with kNThreads threads per block.
//
// The vectorized load/store variant reinterprets the x / out rows as 16-byte vectors, so it
// is only selected when
//   * seqlen is a multiple of kNElts,
//   * the x and out base pointers are aligned to one vector, and
//   * every stride that is actually multiplied by a non-zero index is a multiple of kNElts.
// Otherwise the element-wise (warp-transposed) variant is used; it accepts any seqlen and
// alignment.
//
// Launch errors are surfaced through C10_CUDA_KERNEL_LAUNCH_CHECK.
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;

    constexpr std::uintptr_t kVecBytes = kNElts * sizeof(input_t);
    const bool ptrs_aligned =
        reinterpret_cast<std::uintptr_t>(params.x_ptr) % kVecBytes == 0
        && reinterpret_cast<std::uintptr_t>(params.out_ptr) % kVecBytes == 0;
    // A stride only matters when the corresponding extent exceeds 1.
    const auto stride_ok = [&](auto stride, int extent) {
        return extent <= 1 || stride % kNElts == 0;
    };
    const bool strides_aligned =
        stride_ok(params.x_batch_stride, params.batch) && stride_ok(params.x_c_stride, params.dim)
        && stride_ok(params.out_batch_stride, params.batch) && stride_ok(params.out_c_stride, params.dim);
    const bool use_vec_load = params.seqlen % kNElts == 0 && ptrs_aligned && strides_aligned;

    BOOL_SWITCH(use_vec_load, kIsVecLoad, [&] {
        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize;
        dim3 grid(params.batch, params.dim);

        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

        // Dynamic shared memory at or above 48 KiB needs an explicit opt-in.
        if (kSmemSize >= 48 * 1024) {
            #ifndef USE_ROCM
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            #else
            // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
            #endif
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);

        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}
|
||||||
|
|
||||||
|
// Selects the forward launch specialization for the runtime convolution width.
// Widths 2, 3 and 4 are instantiated (128 threads per block each); any other width launches
// nothing.
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}
|
||||||
|
|
||||||
|
// Compile-time configuration of causal_conv1d_channellast_fwd_kernel.
//
// The cache line is 128 bytes and each thread reads 16 bytes, so 8 threads cover one "row"
// of 32 (4-byte) or 64 (2-byte) channel elements. A warp therefore covers 4 rows, and a
// 128-thread block covers 16 rows per load: each load is 16 x 32|64 elements in the L x C
// dimensions.
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_channellast_fwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;

    static constexpr int kNThreads = kNThreads_;
    static_assert(kNThreads % 32 == 0);
    static constexpr int kNWarps = kNThreads / 32;

    static constexpr int kWidth = kWidth_;
    static constexpr int kChunkSizeL = kChunkSizeL_;
    static constexpr bool kIsVecLoad = kIsVecLoad_;

    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    // Elements per 16-byte vector: 4 for 4-byte types, 8 for 2-byte types.
    static constexpr int kNElts = 16 / kNBytes;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;

    // Channel elements in one 128-byte row, and the threads needed to read it.
    static constexpr int kNEltsPerRow = 128 / kNBytes;
    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;  // Always 8 for now
    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);

    // Rows (sequence positions) covered by one warp and by the whole block per load.
    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;  // Always 4 for now
    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;

    // Loads needed to cover a full L-chunk; the chunk must divide evenly.
    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
};
|
||||||
|
|
||||||
|
// Causal conv1d forward kernel for channel-last inputs.
//
// Launch layout: blockIdx.x = batch entry, blockIdx.y = chunk of kChunkSizeL sequence
// positions, blockIdx.z = chunk of kChunkSizeC channels. Uses only static shared memory.
// Channels are addressed with unit stride, and every global access moves one vec_t
// (kNElts elements) at a time.
//
// Phases:
//   1. Stage the chunk plus the preceding (kWidth - 1) positions into x_smem; positions
//      before the start of the sequence come from initial_states when provided, else zero.
//   2. Blocks of the last L-chunk copy the last (kWidth - 1) inputs to final_states if requested.
//   3. Re-partition the tile so each thread owns kLPerThread consecutive positions of one
//      channel, convolve in float, optionally apply SiLU, and write results back into x_smem.
//   4. Stream the tile from x_smem to `out` with the same vectorized mapping as phase 1.
//
// With kHasSeqIdx, a tap only contributes when its position has the same seq_idx value as
// the output position.
template<typename Ktraits, bool kHasSeqIdx>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
    constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory. Row r holds sequence position chunk_start + r - (kWidth - 1); the inner
    // dimension is padded by kNElts elements.
    __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];

    const int batch_id = blockIdx.x;
    const int chunk_l_id = blockIdx.y;
    const int chunk_c_id = blockIdx.z;
    const int tid = threadIdx.x;
    // Mapping used for global loads/stores: l_idx = row within a load, c_idx = vector within the row.
    const int l_idx = tid / kNThreadsPerC;
    const int c_idx = tid % kNThreadsPerC;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
        + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
    // Only the first L-chunk consumes the initial states.
    input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
        : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    // The last L-chunk will also have enough info to write to final states, since it also contain a few x values
    // from the previous L-chunk.
    input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
        : reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;

    // Phase 1a: stage the chunk itself; out-of-range positions/channels are staged as zeros.
    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
        }
        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }
    // Phase 1b: load the elements from the previous chunk that are needed for convolution.
    if (l_idx < kWidth - 1) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
            && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
        } else if (initial_states != nullptr
                   && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
                   && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
        }
        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }

    __syncthreads();

    // Phase 2: export the final states while x_smem still holds the inputs.
    if (final_states != nullptr
        && l_idx < kWidth - 1
        && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
        // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
        // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
        *reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
    }

    // Phase 3 mapping: row_idx = channel within the chunk, col_idx = group of kLPerThread positions.
    constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
    static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
    constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
    static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
    // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
    static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
    static_assert((kLPerThread & (kLPerThread - 1)) == 0);
    static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
    static_assert(kNThreadsPerRow <= 32);

    const int row_idx = tid / kNThreadsPerRow;
    const int col_idx = tid % kNThreadsPerRow;

    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
    float weight_vals[kWidth] = {0};
    if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
        }
    }
    float x_vals[kWidth - 1 + kLPerThread];
    #pragma unroll
    for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
        x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
    }
    int seq_idx_thread[kWidth - 1 + kLPerThread];
    if constexpr (kHasSeqIdx) {
        #pragma unroll
        for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
            // Absolute sequence position of this entry. Positions outside [0, seqlen) get -1
            // instead of being read: reading them would run past this batch row of seq_idx
            // (and past the allocation for the last batch entry). Outputs at valid positions
            // only look at positions <= their own, so they are unaffected, and outputs at
            // positions >= seqlen are never stored.
            const int pos = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
            seq_idx_thread[i] = pos >= 0 && pos < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
        }
    }

    float out_vals[kLPerThread];
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) {
        out_vals[i] = bias_val;
        const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            if constexpr (!kHasSeqIdx) {
                out_vals[i] += weight_vals[w] * x_vals[i + w];
            } else {
                out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
            }
        }
        if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
    }

    // All threads must be done reading inputs from x_smem before it is reused for outputs.
    __syncthreads();
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
    __syncthreads();

    // Phase 4: vectorized store of the output tile (outputs start at x_smem row 0).
    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t out_vals_store[kNElts];
        reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
        }
    }

}
|
||||||
|
|
||||||
|
// Launches causal_conv1d_channellast_fwd_kernel.
//
// The grid is (batch, ceil(seqlen / kChunkSizeL), ceil(dim / kChunkSizeC)) with kNThreads
// threads per block and an L-chunk of 64 positions. The kernel uses only static shared
// memory, so no dynamic shared memory is requested. The kHasSeqIdx specialization is chosen
// from whether params.seq_idx_ptr is set. Launch errors are surfaced through
// C10_CUDA_KERNEL_LAUNCH_CHECK.
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
        using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
        constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
        constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
        const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
        const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
        dim3 grid(params.batch, n_chunks_L, n_chunks_C);
        dim3 block(Ktraits::kNThreads);
        auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
        kernel<<<grid, block, 0, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}
|
||||||
|
|
||||||
|
// Selects the channel-last forward launch specialization for the runtime convolution width.
// Widths 2, 3 and 4 are instantiated (128 threads per block each); any other width launches
// nothing.
template<typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}
|
||||||
|
|
||||||
|
// Explicit instantiations of the forward entry points for the three element types that the
// host dispatch macro can select (weight type always equals input type).
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);

template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
|
||||||
|
///////
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Compile-time configuration of causal_conv1d_update_kernel.
//
//   kNThreads_ : threads per block (one channel per thread).
//   kWidth_    : convolution width (number of taps / conv_state entries per channel).
//   input_t_   : element type of x, conv_state and out (2 or 4 bytes).
//   weight_t_  : element type of weight and bias.
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;

    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;

    // Only 2-byte and 4-byte element types are supported.
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
};
|
||||||
|
|
||||||
|
// Single-step causal conv1d update kernel.
//
// Launch layout: blockIdx.x selects the batch entry; each thread handles one channel,
// channel = blockIdx.y * kNThreads + threadIdx.x. Threads whose channel is >= params.dim do
// nothing. The kernel uses no shared memory and no barriers.
//
// Per channel it shifts conv_state left by one entry, appends the new input x, stores the
// shifted window back to conv_state, and writes bias + dot(weight, window) (optionally
// passed through SiLU) to out.
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;

    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y * kNThreads + threadIdx.x;
    // Tail threads of the last block have no channel to process.
    if (channel_id >= params.dim) { return; }

    input_t *x_in = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;
    input_t *state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
        + channel_id * params.conv_state_c_stride;
    weight_t *w_row = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *y_out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;

    float acc = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    float taps[kWidth] = {0};
    #pragma unroll
    for (int w = 0; w < kWidth; ++w) { taps[w] = float(w_row[w * params.weight_width_stride]); }

    // window = [state[1], ..., state[kWidth - 1], x]: the oldest state entry is dropped.
    float window[kWidth] = {0};
    #pragma unroll
    for (int w = 0; w < kWidth - 1; ++w) { window[w] = float(state[(w + 1) * params.conv_state_l_stride]); }
    window[kWidth - 1] = float(x_in[0]);
    // Persist the shifted window so the next step sees the updated history.
    #pragma unroll
    for (int w = 0; w < kWidth; ++w) { state[w * params.conv_state_l_stride] = input_t(window[w]); }

    #pragma unroll
    for (int w = 0; w < kWidth; ++w) { acc += taps[w] * window[w]; }
    if (params.silu_activation) { acc = acc / (1 + expf(-acc)); }
    y_out[0] = input_t(acc);
}
|
||||||
|
|
||||||
|
// Launches causal_conv1d_update_kernel on `stream` for a fixed thread count and
// filter width.
//   grid.x = params.batch, grid.y = ceil(params.dim / kNThreads), block = kNThreads.
// `params` is passed to the kernel by value; no dynamic shared memory is requested.
// A launch failure is reported through C10_CUDA_KERNEL_LAUNCH_CHECK.
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
    // Ceil-divide so a partially filled tail block covers the remaining channels;
    // the kernel itself guards against channel_id >= dim.
    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
    auto kernel = &causal_conv1d_update_kernel<Ktraits>;
    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|
||||||
|
|
||||||
|
// Runtime-to-compile-time dispatch on the filter width for the single-step
// update kernel. Widths 2, 3 and 4 are supported, each launched with 64 threads
// per block. Any other width launches nothing (unchanged from the original
// behaviour); NOTE(review): callers are presumably expected to validate
// params.width beforehand — confirm.
template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
    switch (params.width) {
        case 2:
            causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
            break;
        case 3:
            causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
            break;
        case 4:
            causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
            break;
        default:
            // Unsupported width: intentionally a no-op, matching the original if/else chain.
            break;
    }
}
|
||||||
|
|
||||||
|
// Explicit instantiations of the update entry point for each supported
// (input, weight) element-type pair. The parameter is an lvalue reference (`&params`).
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
|
||||||
144
csrc/mamba/causal_conv1d/causal_conv1d.h
Normal file
144
csrc/mamba/causal_conv1d/causal_conv1d.h
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
/******************************************************************************
|
||||||
|
* Copyright (c) 2024, Tri Dao.
|
||||||
|
******************************************************************************/
|
||||||
|
// clang-format off
|
||||||
|
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Plain parameter bundle handed (by value) to the causal-conv1d kernels.
// Strides are added to typed element pointers by the kernels, so they are
// expressed in elements, not bytes. Field order is part of the layout and must
// not be changed.
struct ConvParamsBase {
    using index_t = uint32_t;

    // Problem sizes: batch count, channel count, sequence length, filter width.
    int batch, dim, seqlen, width;
    // When true the kernels apply SiLU to the convolution result.
    bool silu_activation;

    // Strides of the input, weight and output tensors.
    index_t x_batch_stride, x_c_stride, x_l_stride;
    index_t weight_c_stride, weight_width_stride;
    index_t out_batch_stride, out_c_stride, out_l_stride;

    // Strides of the rolling convolution state used by the update kernel.
    index_t conv_state_batch_stride, conv_state_c_stride, conv_state_l_stride;

    // Common data pointers. bias_ptr may be null (treated as zero bias).
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;

    void *__restrict__ conv_state_ptr;

    // NOTE(review): presumably per-position sequence indices for packed batches — confirm.
    void *__restrict__ seq_idx_ptr;

    // No __restrict__ since initial_states could be the same as final_states.
    void * initial_states_ptr;
    index_t initial_states_batch_stride, initial_states_l_stride, initial_states_c_stride;

    void * final_states_ptr;
    index_t final_states_batch_stride, final_states_l_stride, final_states_c_stride;
};
|
||||||
|
|
||||||
|
|
||||||
|
// Platform shims shared by the causal-conv1d kernels. Each helper is defined
// once per backend (CUDA when USE_ROCM is undefined, HIP otherwise):
//   shuffle_xor   - butterfly exchange of `val` with the lane whose id differs by `offset` (XOR).
//   custom_max    - constexpr maximum over a list of sizes.
//   constexpr_min - constexpr minimum of two values.
#ifndef USE_ROCM
    #include <cuda_bf16.h>

    // CUDA: all 32 lanes participate (full mask).
    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor_sync(uint32_t(-1), val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> sizes) {
        return std::max(sizes);
    }

    template<typename T>
    constexpr T constexpr_min(T lhs, T rhs) {
        return std::min(lhs, rhs);
    }

#else
    #include <hip/hip_bf16.h>

    // HIP: the mask-less shuffle is used.
    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor(val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> sizes) {
        return *std::max_element(sizes.begin(), sizes.end());
    }

    template<typename T>
    constexpr T constexpr_min(T lhs, T rhs) {
        return lhs < rhs ? lhs : rhs;
    }
#endif
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Maps a byte count to an unsigned/vector type of exactly that size, used to
// issue wide (vectorized) loads and stores. Only 1, 2, 4, 8 and 16 bytes are
// specialized; any other size yields an empty struct with no `Type`.
template<int BYTES> struct BytesToType {};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Binary addition functor, suitable as the `Operator` argument of Allreduce::run.
template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & lhs, T const & rhs) {
        return lhs + rhs;
    }
};
|
||||||
|
|
||||||
|
// Butterfly all-reduce across groups of THREADS consecutive lanes of a warp.
// After run() returns, every lane in the group holds op() folded over the
// group's values. THREADS must be 4, 8, 16 or 32; the recursion halves the XOR
// distance each step and terminates in the Allreduce<2> specialization.
//
// The lane exchange goes through shuffle_xor(), the backend wrapper defined
// earlier in this header, instead of calling __shfl_xor_sync directly, so the
// same code builds on both the CUDA and the ROCm paths.
template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, shuffle_xor(x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};
|
||||||
|
|
||||||
|
// Base case of the butterfly all-reduce: combine each lane with its XOR-1
// neighbour. Uses the shuffle_xor() backend wrapper defined earlier in this
// header rather than __shfl_xor_sync so the ROCm build takes the HIP shuffle.
template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, shuffle_xor(x, 1));
        return x;
    }
};
|
||||||
28
csrc/mamba/causal_conv1d/static_switch.h
Normal file
28
csrc/mamba/causal_conv1d/static_switch.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// Inspired by
|
||||||
|
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||||
|
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||||
|
// clang-format off
|
||||||
|
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
/// Turns a runtime boolean into a compile-time constant.
///
/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - a callable (typically a lambda); it is invoked exactly once,
///                     with CONST_NAME visible as `true` or `false` accordingly.
///
/// The macro expands to an immediately-invoked lambda, so it is an expression
/// whose value is whatever the callable returns.
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
  [&] {                                               \
    if (COND) {                                       \
      static constexpr bool CONST_NAME = true;        \
      return __VA_ARGS__();                           \
    }                                                 \
    static constexpr bool CONST_NAME = false;         \
    return __VA_ARGS__();                             \
  }()
|
||||||
276
csrc/mamba/mamba_ssm/selective_scan.h
Normal file
276
csrc/mamba/mamba_ssm/selective_scan.h
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
/******************************************************************************
|
||||||
|
* Copyright (c) 2023, Tri Dao.
|
||||||
|
******************************************************************************/
|
||||||
|
// clang-format off
|
||||||
|
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#else
|
||||||
|
#include <hip/hip_bf16.h>
|
||||||
|
#endif
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Plain parameter bundle handed (by value) to the selective-scan kernels.
// The kernel adds the strides to typed element pointers, so they are element
// counts rather than byte counts. Field order is part of the layout and must
// not be changed.
struct SSMParamsBase {
    using index_t = uint32_t;

    // Problem sizes. The kernel iterates n_chunks chunks over the sequence and
    // dstate states per channel.
    int batch, dim, seqlen, dstate, n_groups, n_chunks;
    // The kernel derives a channel's B/C group as dim_id / dim_ngroups_ratio.
    int dim_ngroups_ratio;
    // NOTE(review): presumably whether B / C are input-dependent (vary along the sequence) — confirm.
    bool is_variable_B;
    bool is_variable_C;

    // When true, softplus is applied to delta (after adding delta_bias).
    bool delta_softplus;

    // Per-tensor strides.
    index_t A_d_stride, A_dstate_stride;
    index_t B_batch_stride, B_d_stride, B_dstate_stride, B_group_stride;
    index_t C_batch_stride, C_d_stride, C_dstate_stride, C_group_stride;
    index_t u_batch_stride, u_d_stride;
    index_t delta_batch_stride, delta_d_stride;
    index_t z_batch_stride, z_d_stride;
    index_t out_batch_stride, out_d_stride;
    index_t out_z_batch_stride, out_z_d_stride;

    // Common data pointers. D_ptr and delta_bias_ptr may be null (treated as zero).
    void *__restrict__ A_ptr;
    void *__restrict__ B_ptr;
    void *__restrict__ C_ptr;
    void *__restrict__ D_ptr;
    void *__restrict__ u_ptr;
    void *__restrict__ delta_ptr;
    void *__restrict__ delta_bias_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
    void *__restrict__ z_ptr;
    void *__restrict__ out_z_ptr;
    void *__restrict__ index_ptr;
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// constexpr helpers used when sizing shared memory and vector widths at
// compile time. Defined once per backend (CUDA when USE_ROCM is undefined,
// HIP otherwise):
//   custom_max    - maximum over a list of sizes.
//   constexpr_min - minimum of two values.
#ifndef USE_ROCM

constexpr size_t custom_max(std::initializer_list<size_t> sizes) {
    return std::max(sizes);
}

template<typename T>
constexpr T constexpr_min(T lhs, T rhs) {
    return std::min(lhs, rhs);
}

#else

constexpr size_t custom_max(std::initializer_list<size_t> sizes) {
    return *std::max_element(sizes.begin(), sizes.end());
}

template<typename T>
constexpr T constexpr_min(T lhs, T rhs) {
    return lhs < rhs ? lhs : rhs;
}

#endif
|
||||||
|
|
||||||
|
|
||||||
|
// NOTE(review): presumably the largest state dimension (params.dstate) the scan kernels are meant to support — confirm against the host-side checks.
#define MAX_DSTATE 256
|
||||||
|
|
||||||
|
|
||||||
|
// Component-wise addition for the CUDA float vector types.

inline __device__ float2 operator+(const float2 & a, const float2 & b){
    float2 sum = {a.x + b.x, a.y + b.y};
    return sum;
}

inline __device__ float3 operator+(const float3 &a, const float3 &b) {
    float3 sum = {a.x + b.x, a.y + b.y, a.z + b.z};
    return sum;
}

inline __device__ float4 operator+(const float4 & a, const float4 & b){
    float4 sum = {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
    return sum;
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Maps a byte count to an unsigned/vector type of exactly that size; the scan
// kernel traits use it to pick the vector type for wide loads and stores.
// Only 1, 2, 4, 8 and 16 bytes are specialized; any other size yields an empty
// struct with no `Type`.
template<int BYTES> struct BytesToType {};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Generic element-wise conversion of a fixed-size array to float, relying on
// scalar_t's implicit conversion. Specializations elsewhere in this header
// provide packed conversions for at::Half and at::BFloat16.
template<typename scalar_t, int N>
struct Converter{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
        #pragma unroll
        for (int idx = 0; idx < N; ++idx) {
            dst[idx] = src[idx];
        }
    }
};
|
||||||
|
|
||||||
|
// at::Half -> float conversion that processes two elements at a time by
// viewing the arrays as half2 / float2 pairs. N must therefore be even.
template<int N>
struct Converter<at::Half, N>{
    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        constexpr int kPairs = N / 2;
        // Reinterpret the scalar arrays as arrays of packed pairs.
        auto &src_pairs = reinterpret_cast<const half2 (&)[kPairs]>(src);
        auto &dst_pairs = reinterpret_cast<float2 (&)[kPairs]>(dst);
        #pragma unroll
        for (int p = 0; p < kPairs; ++p) {
            dst_pairs[p] = __half22float2(src_pairs[p]);
        }
    }
};
|
||||||
|
|
||||||
|
// at::BFloat16 -> float conversion that processes two elements at a time by
// viewing the arrays as nv_bfloat162 / float2 pairs. N must be even.
// Only compiled for device passes targeting SM80 or newer; elsewhere the
// generic Converter is used.
#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N>{
    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        constexpr int kPairs = N / 2;
        // Reinterpret the scalar arrays as arrays of packed pairs.
        auto &src_pairs = reinterpret_cast<const nv_bfloat162 (&)[kPairs]>(src);
        auto &dst_pairs = reinterpret_cast<float2 (&)[kPairs]>(dst);
        #pragma unroll
        for (int p = 0; p < kPairs; ++p) {
            dst_pairs[p] = __bfloat1622float2(src_pairs[p]);
        }
    }
};
#endif
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
// Associative operator for the scan. Only the float specialization is defined.
template<typename scalar_t> struct SSMScanOp;

// Each float2 (a, b) encodes the affine map h -> a * h + b. Applying ab0 first
// and ab1 second gives h -> ab1.x * (ab0.x * h + ab0.y) + ab1.y, i.e. the pair
// (ab1.x * ab0.x, ab1.x * ab0.y + ab1.y).
template<>
struct SSMScanOp<float> {
    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
        const float scale = ab1.x * ab0.x;
        const float shift = ab1.x * ab0.y + ab1.y;
        return make_float2(scale, shift);
    }
};
|
||||||
|
|
||||||
|
// Stateful prefix callback for consecutive block-wide scans: it carries the
// running prefix from one scan call to the next.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
    // float2 for real-valued scans, float4 for every other scalar type.
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;

    scan_t running_prefix;

    // Seeds the callback with the prefix carried in from earlier work.
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}

    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    // Returns the prefix as it was before this block and then folds the block's
    // aggregate into the stored prefix.
    __device__ scan_t operator()(scan_t block_aggregate) {
        const scan_t seed = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(seed, block_aggregate);
        return seed;
    }
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Cooperative block load of one chunk of input elements into per-thread registers.
//   u         : first element of the chunk.
//   u_vals    : receives this thread's kNItems elements.
//   smem_load : scratch storage for cub::BlockLoad; in the even-length case it is
//               reinterpreted as the vectorized loader's storage.
//   seqlen    : number of valid elements left; only consulted on the guarded path,
//               where out-of-range slots are filled with 0.
// Must be called by every thread of the block (cub block collective).
template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                  int seqlen) {
    if constexpr (!Ktraits::kIsEvenLen) {
        // Guarded element-wise load for a possibly partial chunk.
        typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    } else {
        // Full chunk: move kNLoads wide vectors per thread instead of kNItems scalars.
        using vec_t = typename Ktraits::vec_t;
        auto& vec_storage = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        typename Ktraits::BlockLoadVecT(vec_storage).Load(
            reinterpret_cast<vec_t*>(u),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
            #ifdef USE_ROCM
                , Ktraits::kNThreads * Ktraits::kNLoads
            #endif
        );
    }
}
|
||||||
|
|
||||||
|
// Cooperative block load of one chunk of int indices into per-thread registers.
//   u               : first index of the chunk.
//   u_vals          : receives this thread's kNItems indices.
//   smem_load_index : scratch storage for cub::BlockLoad; in the even-length case
//                     it is reinterpreted as the vectorized loader's storage.
//   seqlen          : number of valid elements left; only consulted on the guarded
//                     path, where out-of-range slots are filled with 0.
// Must be called by every thread of the block (cub block collective).
//
// The loader types are dependent names, so they are spelled with `typename`
// (as the sibling load_input / load_weight helpers already do); without it the
// name is parsed as a non-type and conforming compilers reject the instantiation.
template<typename Ktraits>
inline __device__ void load_index(int *u,
                                  int (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        // Full chunk: four ints are moved per uint4 vector, kNLoadsIndex vectors per thread.
        auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
        typename Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
            reinterpret_cast<uint4*>(u),
            reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
        );
    } else {
        // Guarded element-wise load for a possibly partial chunk.
        typename Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
    }
}
|
||||||
|
|
||||||
|
// Cooperative block load of one chunk of B or C values, stored as input_t, and
// conversion of the loaded values to float registers.
//   Bvar             : first element of the chunk.
//   B_vals           : receives this thread's kNItems converted values. The
//                      conversion goes through Converter::to_float, whose
//                      destination is a float array, so weight_t has to be float.
//   smem_load_weight : scratch storage for cub::BlockLoad; in the even-length case
//                      it is reinterpreted as the vectorized loader's storage.
//   seqlen           : number of valid elements left; only consulted on the guarded
//                      path, where out-of-range slots are filled with 0.
// Must be called by every thread of the block (cub block collective).
template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
                                   int seqlen) {
    constexpr int kNItems = Ktraits::kNItems;
    // Staging registers in the storage type; converted to float at the end.
    typename Ktraits::input_t staged[kNItems];
    if constexpr (!Ktraits::kIsEvenLen) {
        // Guarded element-wise load for a possibly partial chunk.
        typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, staged, seqlen, 0.f);
    } else {
        // Full chunk: move kNLoads wide vectors per thread.
        using vec_t = typename Ktraits::vec_t;
        auto& vec_storage = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
        typename Ktraits::BlockLoadWeightVecT(vec_storage).Load(
            reinterpret_cast<vec_t*>(Bvar),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(staged)
        );
    }
    // Converter is used here in place of a plain element-by-element copy loop.
    Converter<typename Ktraits::input_t, kNItems>::to_float(staged, B_vals);
}
|
||||||
|
|
||||||
|
// Cooperative block store of one chunk of results.
//   out        : first destination element of the chunk.
//   out_vals   : this thread's kNItems float results; converted to input_t before storing.
//   smem_store : scratch storage for cub::BlockStore; in the even-length case it is
//                reinterpreted as the vectorized storer's storage.
//   seqlen     : number of valid elements left; only consulted on the guarded path,
//                which is handed this count as the number of items to write.
// Must be called by every thread of the block (cub block collective).
template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
                                    const float (&out_vals)[Ktraits::kNItems],
                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
                                    int seqlen) {
    // Narrow the float results to the storage type first.
    typename Ktraits::input_t narrowed[Ktraits::kNItems];
    #pragma unroll
    for (int idx = 0; idx < Ktraits::kNItems; ++idx) {
        narrowed[idx] = out_vals[idx];
    }
    if constexpr (!Ktraits::kIsEvenLen) {
        // Guarded element-wise store for a possibly partial chunk.
        typename Ktraits::BlockStoreT(smem_store).Store(out, narrowed, seqlen);
    } else {
        // Full chunk: move kNLoads wide vectors per thread.
        using vec_t = typename Ktraits::vec_t;
        auto& vec_storage = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        typename Ktraits::BlockStoreVecT(vec_storage).Store(
            reinterpret_cast<vec_t*>(out),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(narrowed)
        );
    }
}
|
||||||
593
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
Normal file
593
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
Normal file
@@ -0,0 +1,593 @@
|
|||||||
|
// clang-format off
|
||||||
|
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include "selective_scan.h"
|
||||||
|
|
||||||
|
#include <c10/util/BFloat16.h>
|
||||||
|
#include <c10/util/Half.h>
|
||||||
|
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
#include <cub/block/block_store.cuh>
|
||||||
|
#include <cub/block/block_scan.cuh>
|
||||||
|
#else
|
||||||
|
#include <hipcub/hipcub.hpp>
|
||||||
|
namespace cub = hipcub;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "selective_scan.h"
|
||||||
|
#include "static_switch.h"
|
||||||
|
|
||||||
|
// Compile-time configuration of selective_scan_fwd_kernel.
//   kNThreads_    : threads per block.
//   kNItems_      : sequence elements held per thread (must be a multiple of 4).
//   kNRows_       : channel rows processed per block.
//   kIsEvenLen_   : selects the unguarded, vectorized load/store path in the
//                   load_* / store_output helpers.
//   kIsVariableB_ / kIsVariableC_ : whether B / C are loaded per chunk through load_weight.
//   kHasZ_        : exported as kHasZ for the kernel.
//   kUseIndex_    : whether the kernel loads the index tensor.
//   input_t_      : storage type of the activations (2 or 4 bytes).
//   weight_t_     : type used for A / B / C values inside the kernel.
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
         bool kIsVariableB_, bool kIsVariableC_,
         bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
    static_assert(kNItems_ % 4 == 0);

    using input_t = input_t_;
    using weight_t = weight_t_;

    static constexpr int kNThreads = kNThreads_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNRows = kNRows_;

    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    // Elements per vector access: 4 for 4-byte types, up to 8 for 2-byte types.
    static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    // Vector accesses needed per thread to cover its kNItems elements.
    static constexpr int kNLoads = kNItems / kNElts;

    static constexpr bool kIsEvenLen = kIsEvenLen_;
    static constexpr bool kIsVariableB = kIsVariableB_;
    static constexpr bool kIsVariableC = kIsVariableC_;
    static constexpr bool kHasZ = kHasZ_;
    static constexpr bool kUseIndex = kUseIndex_;

    // With exactly one vector per thread on the even-length path, the direct
    // (non-transposing) cub algorithms are selected below.
    static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
    // Indices are ints moved as uint4, i.e. four per vector.
    static constexpr int kNLoadsIndex = kNItems / 4;

    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using scan_t = float2;

    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        kDirectIO ? cub::BLOCK_LOAD_DIRECT : cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
        (kIsEvenLen && kNLoadsIndex == 1) ? cub::BLOCK_LOAD_DIRECT : cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        kDirectIO ? cub::BLOCK_LOAD_DIRECT : cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
        kDirectIO ? cub::BLOCK_STORE_DIRECT : cub::BLOCK_STORE_WARP_TRANSPOSE>;
    // Alternatives that were tried: BLOCK_SCAN_RAKING_MEMOIZE and BLOCK_SCAN_RAKING.
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;

    // The I/O helpers share one shared-memory region, so it is sized for the
    // largest of them. The weight loaders are counted once per variable tensor
    // (B and/or C) because the kernel gives the second one its own slot.
    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
                                                   sizeof(typename BlockLoadVecT::TempStorage),
                                                   sizeof(typename BlockLoadIndexT::TempStorage),
                                                   sizeof(typename BlockLoadIndexVecT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                   sizeof(typename BlockStoreT::TempStorage),
                                                   sizeof(typename BlockStoreVecT::TempStorage)});
    // The scan scratch space sits directly after the I/O region.
    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};
|
||||||
|
|
||||||
|
template<typename Ktraits>
|
||||||
|
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
||||||
|
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||||
|
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
||||||
|
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||||
|
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||||
|
constexpr bool kUseIndex = Ktraits::kUseIndex;
|
||||||
|
constexpr int kNThreads = Ktraits::kNThreads;
|
||||||
|
constexpr int kNItems = Ktraits::kNItems;
|
||||||
|
constexpr int kNRows = Ktraits::kNRows;
|
||||||
|
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
||||||
|
using input_t = typename Ktraits::input_t;
|
||||||
|
using weight_t = typename Ktraits::weight_t;
|
||||||
|
using scan_t = typename Ktraits::scan_t;
|
||||||
|
|
||||||
|
// Shared memory.
|
||||||
|
extern __shared__ char smem_[];
|
||||||
|
// cast to lvalue reference of expected type
|
||||||
|
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
||||||
|
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
||||||
|
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
||||||
|
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||||
|
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
||||||
|
auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
|
||||||
|
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
||||||
|
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||||
|
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
||||||
|
// weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
|
||||||
|
// weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
|
||||||
|
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
||||||
|
|
||||||
|
const int batch_id = blockIdx.x;
|
||||||
|
const int dim_id = blockIdx.y;
|
||||||
|
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
||||||
|
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
||||||
|
+ dim_id * kNRows * params.u_d_stride;
|
||||||
|
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
||||||
|
+ dim_id * kNRows * params.delta_d_stride;
|
||||||
|
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
|
||||||
|
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
|
||||||
|
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
||||||
|
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
||||||
|
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
||||||
|
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
|
||||||
|
int *index = !kUseIndex ? nullptr :reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;
|
||||||
|
|
||||||
|
float D_val[kNRows] = {0};
|
||||||
|
if (params.D_ptr != nullptr) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float delta_bias[kNRows] = {0};
|
||||||
|
if (params.delta_bias_ptr != nullptr) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
||||||
|
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
||||||
|
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
||||||
|
// }
|
||||||
|
|
||||||
|
constexpr int kChunkSize = kNThreads * kNItems;
|
||||||
|
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
||||||
|
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
||||||
|
int index_vals_load[kNRows][kNItems];
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
if constexpr (!kDirectIO) {
|
||||||
|
if (r > 0) { __syncthreads(); }
|
||||||
|
}
|
||||||
|
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||||
|
if constexpr (!kDirectIO) { __syncthreads(); }
|
||||||
|
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||||
|
if constexpr (kUseIndex) {
|
||||||
|
load_index<Ktraits>(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if constexpr (kUseIndex) {
|
||||||
|
index += kChunkSize;
|
||||||
|
}
|
||||||
|
u += kChunkSize;
|
||||||
|
delta += kChunkSize;
|
||||||
|
|
||||||
|
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kNItems; ++i) {
|
||||||
|
float u_val = float(u_vals[r][i]);
|
||||||
|
delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
|
||||||
|
if (params.delta_softplus) {
|
||||||
|
delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
|
||||||
|
}
|
||||||
|
delta_u_vals[r][i] = delta_vals[r][i] * u_val;
|
||||||
|
out_vals[r][i] = D_val[r] * u_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
||||||
|
weight_t A_val[kNRows];
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
|
||||||
|
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
||||||
|
constexpr float kLog2e = M_LOG2E;
|
||||||
|
A_val[r] *= kLog2e;
|
||||||
|
}
|
||||||
|
// This variable holds B * C if both B and C are constant across seqlen. If only B varies
|
||||||
|
// across seqlen, this holds C. If only C varies across seqlen, this holds B.
|
||||||
|
// If both B and C vary, this is unused.
|
||||||
|
weight_t BC_val[kNRows];
|
||||||
|
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||||
|
if constexpr (kIsVariableB) {
|
||||||
|
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||||
|
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1));
|
||||||
|
if constexpr (!kIsVariableC) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if constexpr (kIsVariableC) {
|
||||||
|
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||||
|
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||||
|
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 ));
|
||||||
|
if constexpr (!kIsVariableB) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if constexpr (!kIsVariableB && !kIsVariableC) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
if (r > 0) { __syncthreads(); } // Scan could be using the same smem
|
||||||
|
scan_t thread_data[kNItems];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kNItems; ++i) {
|
||||||
|
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
||||||
|
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
||||||
|
|
||||||
|
// Reset A bar for cumulative sequences (Real)
|
||||||
|
if constexpr (kUseIndex) {
|
||||||
|
if (index_vals_load[r][i] == 0) {
|
||||||
|
thread_data[i].x = 0.f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
||||||
|
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
||||||
|
thread_data[i] = make_float2(1.f, 0.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Initialize running total
|
||||||
|
scan_t running_prefix;
|
||||||
|
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
|
||||||
|
running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
|
||||||
|
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
||||||
|
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||||
|
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||||
|
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
||||||
|
);
|
||||||
|
// There's a syncthreads in the scan op, so we don't need to sync here.
|
||||||
|
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
||||||
|
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kNItems; ++i) {
|
||||||
|
const weight_t C_val = !kIsVariableC
|
||||||
|
? BC_val[r]
|
||||||
|
: (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
|
||||||
|
out_vals[r][i] += thread_data[i].y * C_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||||
|
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
||||||
|
__syncthreads();
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
if constexpr (!kDirectIO) {
|
||||||
|
if (r > 0) { __syncthreads(); }
|
||||||
|
}
|
||||||
|
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (kHasZ) {
|
||||||
|
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
||||||
|
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
||||||
|
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
||||||
|
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
||||||
|
#pragma unroll
|
||||||
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
|
input_t z_vals[kNItems];
|
||||||
|
__syncthreads();
|
||||||
|
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kNItems; ++i) {
|
||||||
|
float z_val = z_vals[i];
|
||||||
|
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Bvar += kChunkSize * 1;
|
||||||
|
Cvar += kChunkSize * 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host-side launcher for selective_scan_fwd_kernel.
//
// Selects the kernel-traits specialization for the compile-time thread/item
// counts plus two flags that are only known at run time (whether seqlen is a
// multiple of kNThreads * kNItems, and whether an index tensor was supplied),
// then enqueues the kernel on `stream` with a (batch, dim / kNRows) grid and
// Ktraits::kNThreads threads per block.
//
// Template parameters:
//   kNThreads, kNItems - forwarded to the kernel traits; their product is the
//                        length used for the even-length test below.
//   input_t, weight_t  - element types forwarded to the kernel traits.
// Parameters:
//   params - scan parameters handed to the kernel.
//   stream - CUDA stream the launch is issued on.
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now, which is the same as the previous
    // behavior where each block processed a single row.
    constexpr int kNRows = 1;
    // kIsVariableB, kIsVariableC and kHasZ are all fixed to true to reduce binary size.
    constexpr bool kIsVariableB = true;
    constexpr bool kIsVariableC = true;
    constexpr bool kHasZ = true;
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
            using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
            // Traits-reported shared memory plus one scan_t per (row, state) slot
            // (presumably the kernel's running-prefix buffer - confirm against the kernel).
            constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
            dim3 grid(params.batch, params.dim / kNRows);
            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
            // Large dynamic shared-memory requests need an explicit per-kernel opt-in.
            if (kSmemSize >= 48 * 1024) {
                C10_CUDA_CHECK(cudaFuncSetAttribute(
                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            }
            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
            // Surfaces launch-configuration errors immediately.
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    });
}
|
||||||
|
|
||||||
|
// Forward selective-scan entry point for one (input_t, weight_t) pair.
//
// Picks the <threads per block, items per thread> launch configuration from
// params.seqlen and forwards to selective_scan_fwd_launch. On each path the
// smallest configuration whose threads * items product covers seqlen is used,
// topping out at 128 x 16 for anything longer than 1024.
//
// Parameters:
//   params - scan parameters (only seqlen is inspected here).
//   stream - CUDA stream passed through to the launcher.
template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {

    #ifndef USE_ROCM
        if (params.seqlen <= 128) {
            selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 256) {
            selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
        } else {
            selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
        }
    #else
        // ROCm path never goes below 64 threads per block
        // (presumably to match the 64-wide wavefront - confirm).
        if (params.seqlen <= 256) {
            selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
        } else {
            selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
        }
    #endif
}
|
||||||
|
|
||||||
|
// Explicit instantiations for the three supported input types; the weight type is always float.
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
|
||||||
|
|
||||||
|
// Fails with "<x> must have shape (<dims>)" unless tensor `x` has exactly the
// sizes listed in the variadic arguments.
#define CHECK_SHAPE(x, ...)                                         \
    TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}),     \
                #x " must have shape (" #__VA_ARGS__ ")")
|
||||||
|
|
||||||
|
// Runs the callable passed as the variadic argument with `input_t` bound to
// the C++ type matching the runtime scalar type ITYPE (Half, BFloat16 or
// Float) and `weight_t` bound to float. Any other ITYPE raises an error that
// names the operation NAME and the offending type.
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                       \
    switch (ITYPE) {                                                                         \
        case at::ScalarType::Half: {                                                         \
            using input_t = at::Half;                                                        \
            using weight_t = float;                                                          \
            __VA_ARGS__();                                                                   \
            break;                                                                           \
        }                                                                                    \
        case at::ScalarType::BFloat16: {                                                     \
            using input_t = at::BFloat16;                                                    \
            using weight_t = float;                                                          \
            __VA_ARGS__();                                                                   \
            break;                                                                           \
        }                                                                                    \
        case at::ScalarType::Float: {                                                        \
            using input_t = float;                                                           \
            using weight_t = float;                                                          \
            __VA_ARGS__();                                                                   \
            break;                                                                           \
        }                                                                                    \
        default: {                                                                           \
            AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'");      \
        }                                                                                    \
    }
|
||||||
|
|
||||||
|
|
||||||
|
// Declaration of the typed forward-scan entry point used by the dispatch code below.
template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
|
||||||
|
|
||||||
|
// Fills `params` for a forward selective-scan launch.
//
// The struct is zeroed first, so every field not assigned below is 0 / nullptr.
// Sizes and flags are copied verbatim, data pointers are taken from the
// tensors (or from the raw pointers supplied by the caller), and strides are
// recorded in elements, not bytes.
//
// B and C are interpreted according to is_variable_B / is_variable_C:
//   variable:     stride(0) = batch, stride(1) = group, stride(2) = dstate
//   not variable: stride(0) = d,     stride(1) = dstate
// z / out_z pointers and strides are only populated when has_z is true.
void set_ssm_params_fwd(SSMParamsBase &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t dstate,
                        const size_t n_groups,
                        const size_t n_chunks,
                        const bool is_variable_B,
                        const bool is_variable_C,
                        // device pointers
                        const torch::Tensor u,
                        const torch::Tensor delta,
                        const torch::Tensor A,
                        const torch::Tensor B,
                        const torch::Tensor C,
                        const torch::Tensor out,
                        const torch::Tensor z,
                        const torch::Tensor out_z,
                        void* D_ptr,
                        void* delta_bias_ptr,
                        void* x_ptr,
                        bool has_z,
                        bool delta_softplus,
                        void* index_ptr) {

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.dstate = dstate;
    params.n_groups = n_groups;
    params.n_chunks = n_chunks;
    params.dim_ngroups_ratio = dim / n_groups;

    params.delta_softplus = delta_softplus;

    params.is_variable_B = is_variable_B;
    params.is_variable_C = is_variable_C;

    // Set the pointers and strides.
    params.u_ptr = u.data_ptr();
    params.delta_ptr = delta.data_ptr();
    params.A_ptr = A.data_ptr();
    params.B_ptr = B.data_ptr();
    params.C_ptr = C.data_ptr();
    params.D_ptr = D_ptr;
    params.delta_bias_ptr = delta_bias_ptr;
    params.out_ptr = out.data_ptr();
    params.x_ptr = x_ptr;
    // z / out_z are never dereferenced on the host side when has_z is false.
    params.z_ptr = has_z ? z.data_ptr() : nullptr;
    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;

    params.index_ptr = index_ptr;

    // All strides are in elements, not bytes.
    params.A_d_stride = A.stride(0);
    params.A_dstate_stride = A.stride(1);
    if (!is_variable_B) {
        params.B_d_stride = B.stride(0);
    } else {
        params.B_batch_stride = B.stride(0);
        params.B_group_stride = B.stride(1);
    }
    params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
    if (!is_variable_C) {
        params.C_d_stride = C.stride(0);
    } else {
        params.C_batch_stride = C.stride(0);
        params.C_group_stride = C.stride(1);
    }
    params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
    params.u_batch_stride = u.stride(0);
    params.u_d_stride = u.stride(1);
    params.delta_batch_stride = delta.stride(0);
    params.delta_d_stride = delta.stride(1);
    if (has_z) {
        params.z_batch_stride = z.stride(0);
        params.z_d_stride = z.stride(1);
        params.out_z_batch_stride = out_z.stride(0);
        params.out_z_d_stride = out_z.stride(1);
    }
    params.out_batch_stride = out.stride(0);
    params.out_d_stride = out.stride(1);
}
|
||||||
|
|
||||||
|
// Validates the inputs of the forward selective scan, allocates the outputs
// and launches the CUDA kernel on the current stream of u's device.
//
// Parameters:
//   u, delta       - (batch, dim, seqlen) tensors of the same Float/Half/BFloat16 type.
//   A              - (dim, dstate) float tensor, dstate <= 256.
//   B, C           - (batch, n_groups, dstate, seqlen) tensors of the input type;
//                    the non-variable (rank < 3) forms are rejected.
//   D_             - optional (dim) float tensor.
//   z_             - (batch, dim, seqlen) tensor of the input type; required.
//   delta_bias_    - optional (dim) float tensor.
//   delta_softplus - forwarded to the kernel parameters.
//   index_         - optional (batch, seqlen) int tensor.
//   x              - (batch, dim, n_chunks, dstate * 2) float tensor with
//                    n_chunks = ceil(seqlen / 2048); required, and returned to
//                    the caller after the kernel has run.
//
// Returns {out, x, out_z}, where out is allocated like delta and out_z like z.
// Raises a c10::Error (via TORCH_CHECK) when any precondition is violated.
std::vector<torch::Tensor>
selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                  const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
                  const c10::optional<torch::Tensor> &D_,
                  const c10::optional<torch::Tensor> &z_,
                  const c10::optional<torch::Tensor> &delta_bias_,
                  bool delta_softplus,
                  const c10::optional<torch::Tensor> &index_,
                  const c10::optional<torch::Tensor> &x) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float);

    const bool is_variable_B = B.dim() >= 3;
    const bool is_variable_C = C.dim() >= 3;

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

    // sizes[0..2] is indexed right below, so reject any other rank up front
    // instead of reading past the end of the size array.
    TORCH_CHECK(u.dim() == 3, "u must have shape (batch_size, dim, seqlen)");
    const auto sizes = u.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int dstate = A.size(1);
    const int n_groups = is_variable_B ? B.size(1) : 1;

    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");

    CHECK_SHAPE(u, batch_size, dim, seqlen);
    CHECK_SHAPE(delta, batch_size, dim, seqlen);
    CHECK_SHAPE(A, dim, dstate);
    TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size");
    CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen );
    TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);

    TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size");
    CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
    TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }
    if (index_.has_value()) {
        auto index = index_.value();
        TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(index.is_cuda());
        CHECK_SHAPE(index, batch_size, seqlen);
    }

    at::Tensor z, out_z;
    const bool has_z = z_.has_value();
    TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size");
    z = z_.value();
    TORCH_CHECK(z.scalar_type() == input_type);
    TORCH_CHECK(z.is_cuda());
    TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
    CHECK_SHAPE(z, batch_size, dim, seqlen);
    out_z = torch::empty_like(z);

    const int n_chunks = (seqlen + 2048 - 1) / 2048;
    // const int n_chunks = (seqlen + 1024 - 1) / 1024;
    // at::Tensor out = torch::empty_like(u);
    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
    at::Tensor out = torch::empty_like(delta);

    // x is dereferenced unconditionally below (kernel parameter and return
    // value), so a missing x must be rejected with a clear message here.
    TORCH_CHECK(x.has_value(), "selective_scan_fwd requires x to be provided");
    const torch::Tensor &x_state = x.value();
    TORCH_CHECK(x_state.scalar_type() == weight_type);
    TORCH_CHECK(x_state.is_cuda());
    TORCH_CHECK(x_state.stride(-1) == 1);
    CHECK_SHAPE(x_state, batch_size, dim, n_chunks, dstate * 2);

    SSMParamsBase params;
    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                       u, delta, A, B, C, out, z, out_z,
                       D_.has_value() ? D_.value().data_ptr() : nullptr,
                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                       x_state.data_ptr(),
                       has_z,
                       delta_softplus,
                       index_.has_value() ? index_.value().data_ptr() : nullptr);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)u.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
        selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
    });
    std::vector<at::Tensor> result = {out, x_state};
    if (has_z) { result.push_back(out_z); }
    return result;
}
|
||||||
|
|
||||||
28
csrc/mamba/mamba_ssm/static_switch.h
Normal file
28
csrc/mamba/mamba_ssm/static_switch.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// Inspired by
|
||||||
|
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||||
|
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
/// Turns a runtime boolean into a compile-time constant.
///
/// @param COND       - boolean expression evaluated once at run time
/// @param CONST_NAME - name of the `constexpr bool` made visible to the callable
/// @param ...        - a callable (typically a lambda); it is instantiated for
///                     both values of CONST_NAME and the matching one is invoked
///
/// The whole expression evaluates to whatever the callable returns.
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      constexpr bool CONST_NAME = true;         \
      return __VA_ARGS__();                     \
    }                                           \
    constexpr bool CONST_NAME = false;          \
    return __VA_ARGS__();                       \
  }()
|
||||||
1740
csrc/moe/marlin_moe_ops.cu
Normal file
1740
csrc/moe/marlin_moe_ops.cu
Normal file
File diff suppressed because it is too large
Load Diff
12
csrc/moe/marlin_moe_ops.h
Normal file
12
csrc/moe/marlin_moe_ops.h
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
// Declaration of the Marlin mixture-of-experts GEMM entry point; the
// definition lives outside this header.
// NOTE(review): parameter semantics are not visible from this declaration -
// consult the implementation before relying on tensor shapes or dtypes.
torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a,
    const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids,
    const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids,
    const torch::Tensor& b_scales,
    const torch::Tensor& g_idx,
    const torch::Tensor& perm,
    torch::Tensor& workspace,
    int64_t size_m,
    int64_t size_n,
    int64_t size_k,
    bool is_k_full,
    int64_t num_experts,
    int64_t topk,
    int64_t moe_block_size,
    bool replicate_input,
    bool apply_weights);
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
#include "core/registration.h"
|
#include "core/registration.h"
|
||||||
#include "moe_ops.h"
|
#include "moe_ops.h"
|
||||||
|
#include "marlin_moe_ops.h"
|
||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||||
// Apply topk softmax to the gating outputs.
|
// Apply topk softmax to the gating outputs.
|
||||||
@@ -7,6 +8,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|||||||
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
||||||
"token_expert_indices, Tensor gating_output) -> ()");
|
"token_expert_indices, Tensor gating_output) -> ()");
|
||||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
m.def(
|
||||||
|
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||||
|
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||||
|
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
|
||||||
|
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
|
||||||
|
"bool replicate_input, bool apply_weights) -> Tensor");
|
||||||
|
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||||
|
|||||||
96
csrc/ops.h
96
csrc/ops.h
@@ -54,21 +54,32 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
|||||||
|
|
||||||
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
|
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
|
||||||
|
|
||||||
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
||||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
int64_t block_size, torch::Tensor& input_tokens,
|
||||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
torch::Tensor& sampled_token_ids,
|
||||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables);
|
torch::Tensor& input_positions,
|
||||||
|
torch::Tensor& seq_lens,
|
||||||
|
torch::Tensor& slot_mapping,
|
||||||
|
torch::Tensor& block_tables);
|
||||||
|
|
||||||
|
void advance_step_flashinfer(
|
||||||
|
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||||
|
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||||
|
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||||
|
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||||
|
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& scales,
|
const torch::Tensor& scales,
|
||||||
const torch::Tensor& codebook_partition_sizes,
|
const std::vector<int64_t>& codebook_partition_sizes,
|
||||||
const std::optional<torch::Tensor>& bias);
|
const std::optional<torch::Tensor>& bias);
|
||||||
|
|
||||||
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
torch::Tensor aqlm_dequant(
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& codes, const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& codebook_partition_sizes);
|
const std::vector<int64_t>& codebook_partition_sizes);
|
||||||
|
|
||||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||||
@@ -83,6 +94,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||||
int64_t size_m, int64_t size_n, int64_t size_k);
|
int64_t size_m, int64_t size_n, int64_t size_k);
|
||||||
|
|
||||||
|
namespace machete {
|
||||||
|
|
||||||
|
std::vector<std::string> supported_schedules(
|
||||||
|
vllm::ScalarTypeTorchPtr const& btype);
|
||||||
|
|
||||||
|
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||||
|
vllm::ScalarTypeTorchPtr const& btype,
|
||||||
|
c10::optional<torch::Tensor> const& scales,
|
||||||
|
c10::optional<torch::Tensor> const& zeros,
|
||||||
|
c10::optional<int64_t> group_size,
|
||||||
|
c10::optional<torch::Tensor> const& C,
|
||||||
|
c10::optional<double> alpha, c10::optional<double> beta,
|
||||||
|
c10::optional<std::string> schedule);
|
||||||
|
|
||||||
|
torch::Tensor prepack_B(torch::Tensor const& B,
|
||||||
|
vllm::ScalarTypeTorchPtr const& btype);
|
||||||
|
|
||||||
|
}; // namespace machete
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor& b_meta,
|
torch::Tensor& b_meta,
|
||||||
torch::Tensor& b_scales,
|
torch::Tensor& b_scales,
|
||||||
@@ -104,9 +134,26 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|||||||
int64_t size_k, int64_t size_n,
|
int64_t size_k, int64_t size_n,
|
||||||
int64_t num_bits);
|
int64_t num_bits);
|
||||||
|
|
||||||
|
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||||
|
torch::Tensor& perm, c10::SymInt size_k,
|
||||||
|
c10::SymInt size_n, int64_t num_bits);
|
||||||
|
|
||||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||||
int64_t size_n, int64_t num_bits);
|
int64_t size_n, int64_t num_bits);
|
||||||
|
|
||||||
|
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||||
|
c10::SymInt size_k, c10::SymInt size_n,
|
||||||
|
int64_t num_bits);
|
||||||
|
|
||||||
|
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||||
|
int64_t n);
|
||||||
|
|
||||||
|
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
||||||
|
int64_t type, int64_t row);
|
||||||
|
|
||||||
|
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||||
|
int64_t row);
|
||||||
|
|
||||||
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||||
@@ -119,6 +166,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
torch::Tensor const& b_scales,
|
torch::Tensor const& b_scales,
|
||||||
c10::optional<torch::Tensor> const& bias);
|
c10::optional<torch::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
c10::optional<torch::Tensor> const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias);
|
||||||
|
|
||||||
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
||||||
torch::Tensor const& b_q_weight,
|
torch::Tensor const& b_q_weight,
|
||||||
torch::Tensor const& s_tok,
|
torch::Tensor const& s_tok,
|
||||||
@@ -134,9 +189,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
|||||||
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
torch::Tensor& scales);
|
torch::Tensor& scales);
|
||||||
|
|
||||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
|
||||||
torch::Tensor lookup_table);
|
|
||||||
|
|
||||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||||
torch::Tensor b_gptq_qzeros,
|
torch::Tensor b_gptq_qzeros,
|
||||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||||
@@ -159,6 +211,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|||||||
torch::Tensor experts_ids,
|
torch::Tensor experts_ids,
|
||||||
torch::Tensor num_tokens_post_pad);
|
torch::Tensor num_tokens_post_pad);
|
||||||
|
|
||||||
|
std::vector<torch::Tensor> selective_scan_fwd(
|
||||||
|
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
||||||
|
const torch::Tensor& B, const torch::Tensor& C,
|
||||||
|
const c10::optional<torch::Tensor>& D_,
|
||||||
|
const c10::optional<torch::Tensor>& z_,
|
||||||
|
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
||||||
|
const c10::optional<torch::Tensor>& index_,
|
||||||
|
const c10::optional<torch::Tensor>& x);
|
||||||
|
|
||||||
|
at::Tensor causal_conv1d_update(const at::Tensor& x,
|
||||||
|
const at::Tensor& conv_state,
|
||||||
|
const at::Tensor& weight,
|
||||||
|
const c10::optional<at::Tensor>& bias_,
|
||||||
|
bool silu_activation);
|
||||||
|
|
||||||
|
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||||
|
const c10::optional<at::Tensor>& bias_,
|
||||||
|
const c10::optional<at::Tensor>& seq_idx_,
|
||||||
|
const c10::optional<at::Tensor>& initial_states_,
|
||||||
|
const c10::optional<at::Tensor>& final_states_out_,
|
||||||
|
bool silu_activation);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
using fptr_t = int64_t;
|
using fptr_t = int64_t;
|
||||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||||
|
|||||||
@@ -12,13 +12,11 @@ namespace prepare_inputs {
|
|||||||
|
|
||||||
//
|
//
|
||||||
template <int const num_threads>
|
template <int const num_threads>
|
||||||
__global__ void advance_step_kernel(int num_seqs, int num_queries,
|
__global__ void advance_step_flashattn_kernel(
|
||||||
int block_size, long* input_tokens_ptr,
|
int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
|
||||||
long const* sampled_token_ids_ptr,
|
long const* sampled_token_ids_ptr, long* input_positions_ptr,
|
||||||
long* input_positions_ptr,
|
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
|
||||||
int* seq_lens_ptr, long* slot_mapping_ptr,
|
int64_t const block_tables_stride) {
|
||||||
int const* block_tables_ptr,
|
|
||||||
int64_t const block_tables_stride) {
|
|
||||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||||
|
|
||||||
if (blockIdx.x >= num_query_blocks) {
|
if (blockIdx.x >= num_query_blocks) {
|
||||||
@@ -79,16 +77,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_step(int num_seqs, int num_queries, int block_size,
|
__global__ void advance_step_flashinfer_kernel(
|
||||||
torch::Tensor& input_tokens, // type: long
|
int num_threads, int num_seqs, int num_queries, int block_size,
|
||||||
torch::Tensor& sampled_token_ids, // type: long
|
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
|
||||||
torch::Tensor& input_positions, // type: long
|
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
|
||||||
torch::Tensor& seq_lens, // type: int
|
int const* block_tables_ptr, int64_t const block_tables_stride,
|
||||||
torch::Tensor& slot_mapping, // type: long
|
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
|
||||||
torch::Tensor& block_tables) { // type: int
|
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||||
|
|
||||||
|
if (blockIdx.x < num_query_blocks) {
|
||||||
|
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
|
||||||
|
|
||||||
|
if (cur_query_id < num_queries) {
|
||||||
|
// Update input_tokens
|
||||||
|
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
|
||||||
|
|
||||||
|
int seq_len = seq_lens_ptr[cur_query_id];
|
||||||
|
int next_seq_len = seq_len + 1;
|
||||||
|
int next_input_pos = next_seq_len - 1;
|
||||||
|
|
||||||
|
// Update seq_lens
|
||||||
|
seq_lens_ptr[cur_query_id] = next_seq_len;
|
||||||
|
// Update input_positions
|
||||||
|
input_positions_ptr[cur_query_id] = next_input_pos;
|
||||||
|
|
||||||
|
int const* seq_block_tables_ptr =
|
||||||
|
block_tables_ptr + block_tables_stride * cur_query_id;
|
||||||
|
|
||||||
|
int block_index = next_input_pos / block_size;
|
||||||
|
int block_offset = next_input_pos % block_size;
|
||||||
|
|
||||||
|
// Update paged_kv_last_page_len
|
||||||
|
paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
|
||||||
|
|
||||||
|
int slot_num =
|
||||||
|
seq_block_tables_ptr[block_index] * block_size + block_offset;
|
||||||
|
// Update slot_mapping
|
||||||
|
slot_mapping_ptr[cur_query_id] = slot_num;
|
||||||
|
block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void advance_step_flashinfer_indptr_kernel(
|
||||||
|
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
|
||||||
|
int* block_table_bound_ptr) {
|
||||||
|
int idx = blockIdx.x * num_threads + threadIdx.x;
|
||||||
|
|
||||||
|
// Update paged_kv_indptr
|
||||||
|
if (idx < num_queries) {
|
||||||
|
int sum = 0;
|
||||||
|
for (int i = 0; i <= idx; ++i) {
|
||||||
|
sum += block_table_bound_ptr[i];
|
||||||
|
}
|
||||||
|
paged_kv_indptr_ptr[idx + 1] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void advance_step_flashinfer_indices_kernel(
|
||||||
|
int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
|
||||||
|
int64_t const block_tables_stride, int* paged_kv_indices_ptr,
|
||||||
|
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
|
||||||
|
int idx = blockIdx.x * num_threads + threadIdx.x;
|
||||||
|
int row = idx / block_tables_stride;
|
||||||
|
int col = idx % block_tables_stride;
|
||||||
|
|
||||||
|
if (row < num_queries && col < block_table_bound_ptr[row]) {
|
||||||
|
paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
|
||||||
|
block_tables_ptr[row * block_tables_stride + col];
|
||||||
|
}
|
||||||
|
// if cudagraph, fill padded seqs with the last valid seq's indptr
|
||||||
|
if (num_queries < row && row <= num_seqs) {
|
||||||
|
paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
|
||||||
|
torch::Tensor& input_tokens, // type: long
|
||||||
|
torch::Tensor& sampled_token_ids, // type: long
|
||||||
|
torch::Tensor& input_positions, // type: long
|
||||||
|
torch::Tensor& seq_lens, // type: int
|
||||||
|
torch::Tensor& slot_mapping, // type: long
|
||||||
|
torch::Tensor& block_tables) { // type: int
|
||||||
|
|
||||||
if (logging) {
|
if (logging) {
|
||||||
printf("advance_step:\n");
|
printf("advance_step_flashattn:\n");
|
||||||
printf(" num_seqs = %d\n", num_seqs);
|
printf(" num_seqs = %d\n", num_seqs);
|
||||||
printf(" num_queries = %d\n", num_queries);
|
printf(" num_queries = %d\n", num_queries);
|
||||||
printf(" block_size = %d\n", block_size);
|
printf(" block_size = %d\n", block_size);
|
||||||
@@ -108,24 +181,126 @@ void advance_step(int num_seqs, int num_queries, int block_size,
|
|||||||
int blocks;
|
int blocks;
|
||||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||||
|
|
||||||
advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
|
advance_step_flashattn_kernel<max_threads>
|
||||||
num_seqs, num_queries, block_size,
|
<<<blocks, max_threads, 0, stream>>>(
|
||||||
|
num_seqs, num_queries, block_size,
|
||||||
|
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||||
|
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||||
|
reinterpret_cast<long*>(input_positions.data_ptr()),
|
||||||
|
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
||||||
|
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
||||||
|
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||||
|
block_tables.stride(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
void advance_step_flashinfer(
|
||||||
|
int num_seqs, int num_queries, int block_size,
|
||||||
|
torch::Tensor& input_tokens, // type: long
|
||||||
|
torch::Tensor& sampled_token_ids, // type: long
|
||||||
|
torch::Tensor& input_positions, // type: long
|
||||||
|
torch::Tensor& seq_lens, // type: int
|
||||||
|
torch::Tensor& slot_mapping, // type: long
|
||||||
|
torch::Tensor& block_tables, // type: int
|
||||||
|
torch::Tensor& paged_kv_indices, // type: int
|
||||||
|
torch::Tensor& paged_kv_indptr, // type: int
|
||||||
|
torch::Tensor& paged_kv_last_page_len, // type: int
|
||||||
|
torch::Tensor& block_table_bound) { // type: int
|
||||||
|
|
||||||
|
if (logging) {
|
||||||
|
printf("advance_step_flashinfer:\n");
|
||||||
|
printf(" num_seqs = %d\n", num_seqs);
|
||||||
|
printf(" num_queries = %d\n", num_queries);
|
||||||
|
printf(" block_size = %d\n", block_size);
|
||||||
|
printf(" block_tables.stride(0) = %d\n", block_tables.stride(0));
|
||||||
|
}
|
||||||
|
// Verify all tensors
|
||||||
|
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
||||||
|
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
|
||||||
|
// at::kLong);
|
||||||
|
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
||||||
|
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
||||||
|
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
|
||||||
|
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
||||||
|
|
||||||
|
verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
|
||||||
|
verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
|
||||||
|
verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
|
||||||
|
at::kInt);
|
||||||
|
|
||||||
|
verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
|
||||||
|
|
||||||
|
int dev = sampled_token_ids.get_device();
|
||||||
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||||
|
|
||||||
|
int blocks;
|
||||||
|
int threads;
|
||||||
|
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||||
|
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
|
||||||
|
if (logging) {
|
||||||
|
printf("launching kernel with %d blocks\n", blocks);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(will): support arbitrary block_tables stride
|
||||||
|
if ((blocks * threads) / block_tables.stride(0) < num_queries) {
|
||||||
|
TORCH_CHECK(false,
|
||||||
|
"multi-step: not enough threads to map block_table to"
|
||||||
|
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
|
||||||
|
"of seqs,",
|
||||||
|
" increasing the block size or take smaller steps.",
|
||||||
|
" num_queries = ", num_queries,
|
||||||
|
" block_tables.stride(0) = ", block_tables.stride(0),
|
||||||
|
" blocks = ", blocks, " max_threads = ", threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
|
||||||
|
threads, num_seqs, num_queries, block_size,
|
||||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
reinterpret_cast<long*>(input_positions.data_ptr()),
|
||||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
||||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
||||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||||
block_tables.stride(0));
|
block_tables.stride(0),
|
||||||
|
reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
|
||||||
|
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||||
|
|
||||||
|
advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
|
||||||
|
threads, num_seqs, num_queries,
|
||||||
|
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
||||||
|
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||||
|
|
||||||
|
advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
|
||||||
|
threads, num_seqs, num_queries,
|
||||||
|
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||||
|
block_tables.stride(0),
|
||||||
|
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
|
||||||
|
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
||||||
|
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace prepare_inputs
|
} // namespace prepare_inputs
|
||||||
|
|
||||||
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
||||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
int64_t block_size, torch::Tensor& input_tokens,
|
||||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
torch::Tensor& sampled_token_ids,
|
||||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
|
torch::Tensor& input_positions,
|
||||||
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
|
torch::Tensor& seq_lens,
|
||||||
sampled_token_ids, input_positions, seq_lens,
|
torch::Tensor& slot_mapping,
|
||||||
slot_mapping, block_tables);
|
torch::Tensor& block_tables) {
|
||||||
|
prepare_inputs::advance_step_flashattn(
|
||||||
|
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||||
|
input_positions, seq_lens, slot_mapping, block_tables);
|
||||||
|
}
|
||||||
|
|
||||||
|
void advance_step_flashinfer(
|
||||||
|
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||||
|
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||||
|
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||||
|
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||||
|
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
|
||||||
|
prepare_inputs::advance_step_flashinfer(
|
||||||
|
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||||
|
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
|
||||||
|
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
|
||||||
}
|
}
|
||||||
@@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate the partition sizes.
|
// Accumulate the partition sizes.
|
||||||
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
|
int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
|
||||||
int4 cumulative_sizes;
|
int4 cumulative_sizes;
|
||||||
auto cumulative_size = &cumulative_sizes.x;
|
auto cumulative_size = &cumulative_sizes.x;
|
||||||
int i = 0;
|
size_t i = 0;
|
||||||
int last = 0;
|
int last = 0;
|
||||||
assert(codebook_partition_sizes.size(0) <= 4);
|
assert(codebook_partition_sizes.size() <= 4);
|
||||||
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
|
for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
|
||||||
*cumulative_size = codebook_partition_sizes[i].item<int>() + last;
|
*cumulative_size = codebook_partition_sizes[i] + last;
|
||||||
last = *cumulative_size;
|
last = *cumulative_size;
|
||||||
}
|
}
|
||||||
// fill in the rest with unreachable.
|
// fill in the rest with unreachable.
|
||||||
@@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
|
|||||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& scales,
|
const torch::Tensor& scales,
|
||||||
const torch::Tensor& codebook_partition_sizes,
|
const std::vector<int64_t>& codebook_partition_sizes,
|
||||||
const std::optional<torch::Tensor>& bias) {
|
const std::optional<torch::Tensor>& bias) {
|
||||||
int4 cumulative_sizes =
|
int4 cumulative_sizes =
|
||||||
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||||
|
|
||||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
|
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
|
||||||
int const entries = codebooks.size(1);
|
int const entries = codebooks.size(1);
|
||||||
|
|
||||||
if (nbooks == 1 && entries == (1 << 16)) {
|
if (nbooks == 1 && entries == (1 << 16)) {
|
||||||
@@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
torch::Tensor aqlm_dequant(
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& codes, const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& codebook_partition_sizes) {
|
const std::vector<int64_t>& codebook_partition_sizes) {
|
||||||
int4 cumulative_sizes =
|
int4 cumulative_sizes =
|
||||||
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||||
|
|
||||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
|
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
|
||||||
int const entries = codebooks.size(1);
|
int const entries = codebooks.size(1);
|
||||||
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
|
||||||
@@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
|||||||
auto in_features = codes.size(1) * 8;
|
auto in_features = codes.size(1) * 8;
|
||||||
auto out_features = codes.size(0);
|
auto out_features = codes.size(0);
|
||||||
|
|
||||||
assert(out_features = codebook_partition_sizes.sum().item<int>());
|
assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
|
||||||
|
codebook_partition_sizes.end(), 0));
|
||||||
|
|
||||||
auto weights = torch::empty({out_features, in_features},
|
auto weights = torch::empty({out_features, in_features},
|
||||||
torch::TensorOptions()
|
torch::TensorOptions()
|
||||||
|
|||||||
@@ -3,7 +3,14 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "../../dispatch_utils.h"
|
#include "../../dispatch_utils.h"
|
||||||
#include "../../reduction_utils.cuh"
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#include <cub/util_type.cuh>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
#else
|
||||||
|
#include <hipcub/util_type.hpp>
|
||||||
|
#include <hipcub/hipcub.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
static inline __device__ int8_t float_to_int8_rn(float x) {
|
static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
@@ -55,7 +62,10 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
|||||||
absmax_val = val > absmax_val ? val : absmax_val;
|
absmax_val = val > absmax_val ? val : absmax_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
|
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage reduceStorage;
|
||||||
|
float const block_absmax_val_maybe =
|
||||||
|
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||||
__shared__ float block_absmax_val;
|
__shared__ float block_absmax_val;
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
block_absmax_val = block_absmax_val_maybe;
|
block_absmax_val = block_absmax_val_maybe;
|
||||||
|
|||||||
147
csrc/quantization/cutlass_w8a8/Epilogues.md
Normal file
147
csrc/quantization/cutlass_w8a8/Epilogues.md
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# CUTLASS Epilogues
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
|
||||||
|
|
||||||
|
Currently, we only support symmetric quantization for weights,
|
||||||
|
and symmetric and asymmetric quantization for activations.
|
||||||
|
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
|
||||||
|
|
||||||
|
There are 4 epilogues:
|
||||||
|
1. ScaledEpilogue: symmetric quantization for activations, no bias.
|
||||||
|
1. ScaledEpilogueBias: symmetric quantization for activations, supports bias.
|
||||||
|
1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
|
||||||
|
1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.
|
||||||
|
|
||||||
|
We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
|
||||||
|
Instead, if no bias is passed, the epilogue will use 0 as the bias.
|
||||||
|
That induces a redundant addition operation (and runtime check), but the performance impact is minor.
|
||||||
|
|
||||||
|
## Underlying Linear Algebra
|
||||||
|
|
||||||
|
More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).
|
||||||
|
|
||||||
|
If $` \widehat X `$ is the quantized $` X `$, our matrices become the following
|
||||||
|
|
||||||
|
```math
|
||||||
|
A = s_a (\widehat A - J_a z_a)
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
B = s_b \widehat B
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
D = A B + C
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
D = s_a s_b \widehat D + C
|
||||||
|
```
|
||||||
|
|
||||||
|
Here, D is the output of the GEMM, and C is the bias.
|
||||||
|
A is the activations and supports asymmetric quantization,
|
||||||
|
and B is the weights and only supports symmetric quantization.
|
||||||
|
$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
|
||||||
|
$` z_a `$ is the zero-point for activations, and $` J_a `$ is the matrix of all ones with the same dimensions as $` A `$.
|
||||||
|
Additional epilogues would be required to support asymmetric quantization for weights.
|
||||||
|
|
||||||
|
Expanding further, we can calculate $` \widehat D `$ as follows:
|
||||||
|
|
||||||
|
```math
|
||||||
|
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
|
||||||
|
and $` J_a \widehat B `$ is known ahead of time.
|
||||||
|
Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.
|
||||||
|
|
||||||
|
## Epilogues
|
||||||
|
|
||||||
|
### ScaledEpilogue
|
||||||
|
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
|
||||||
|
The output of the GEMM is:
|
||||||
|
|
||||||
|
```math
|
||||||
|
\widehat D = \widehat A \widehat B
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
D = s_a s_b \widehat D
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
D = s_a s_b \widehat A \widehat B
|
||||||
|
```
|
||||||
|
|
||||||
|
Epilogue parameters:
|
||||||
|
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
||||||
|
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
||||||
|
|
||||||
|
### ScaledEpilogueBias
|
||||||
|
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
|
||||||
|
The output of the GEMM is:
|
||||||
|
|
||||||
|
```math
|
||||||
|
\widehat D = \widehat A \widehat B
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
D = s_a s_b \widehat D + C
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
D = s_a s_b \widehat A \widehat B + C
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Epilogue parameters:
|
||||||
|
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
||||||
|
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
||||||
|
- `bias` is the bias, is always per-channel (row-vector).
|
||||||
|
|
||||||
|
### ScaledEpilogueAzp
|
||||||
|
This epilogue computes the asymmetric per-tensor quantization for activations with bias.
|
||||||
|
The output of the GEMM is:
|
||||||
|
|
||||||
|
```math
|
||||||
|
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
D = s_a s_b \widehat D + C
|
||||||
|
```
|
||||||
|
```math
|
||||||
|
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
|
||||||
|
```
|
||||||
|
|
||||||
|
Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
|
||||||
|
That is precomputed and stored in `azp_with_adj` as a row-vector.
|
||||||
|
|
||||||
|
Epilogue parameters:
|
||||||
|
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
||||||
|
- Generally this will be per-tensor as the zero-points are per-tensor.
|
||||||
|
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
||||||
|
- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector).
|
||||||
|
- `bias` is the bias, is always per-channel (row-vector).
|
||||||
|
|
||||||
|
To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
|
||||||
|
|
||||||
|
### ScaledEpilogueAzpPerToken
|
||||||
|
This epilogue computes the asymmetric per-token quantization for activations with bias.
|
||||||
|
|
||||||
|
The output of the GEMM is the same as above, except that $` z_a `$ is now a column-vector (one zero-point per token).
|
||||||
|
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
|
||||||
|
|
||||||
|
Epilogue parameters:
|
||||||
|
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
|
||||||
|
- Generally this will be per-token as the zero-points are per-token.
|
||||||
|
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
|
||||||
|
- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector).
|
||||||
|
- `azp` is the zero-point (`z_a`), is per-token (column-vector).
|
||||||
|
- `bias` is the bias, is always per-channel (row-vector).
|
||||||
|
|
||||||
|
To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
|
||||||
|
|
||||||
|
The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
|
||||||
|
```
|
||||||
|
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
|
||||||
|
```
|
||||||
@@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast {
|
|||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
|
||||||
|
template<
|
||||||
|
class ThreadMap,
|
||||||
|
class Element,
|
||||||
|
class StrideMNL
|
||||||
|
>
|
||||||
|
struct VisitorRowOrZeroBroadcast {
|
||||||
|
|
||||||
|
// This struct has been modified to remove null_default (because it's always 0)
|
||||||
|
struct Arguments {
|
||||||
|
Element const* ptr_row = nullptr;
|
||||||
|
StrideMNL dRow = {};
|
||||||
|
};
|
||||||
|
|
||||||
|
using Params = Arguments;
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static constexpr Params
|
||||||
|
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static size_t
|
||||||
|
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SharedStorage {};
|
||||||
|
|
||||||
|
// Global load type
|
||||||
|
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||||
|
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||||
|
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
VisitorRowOrZeroBroadcast() { }
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||||
|
: params_ptr(&params) { }
|
||||||
|
|
||||||
|
Params const* params_ptr;
|
||||||
|
|
||||||
|
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||||
|
struct Callbacks : EmptyCallbacks {
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
Callbacks(
|
||||||
|
GTensor&& tC_gRow,
|
||||||
|
RTensor&& tC_rRow,
|
||||||
|
CTensor&& tC_cRow,
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
Params const* params_ptr
|
||||||
|
):
|
||||||
|
tC_gRow(cute::forward<GTensor>(tC_gRow)),
|
||||||
|
tC_rRow(cute::forward<RTensor>(tC_rRow)),
|
||||||
|
tC_cRow(cute::forward<CTensor>(tC_cRow)),
|
||||||
|
n(get<1>(problem_shape)),
|
||||||
|
params_ptr(params_ptr) { }
|
||||||
|
|
||||||
|
GTensor tC_gRow;
|
||||||
|
RTensor tC_rRow;
|
||||||
|
CTensor tC_cRow;
|
||||||
|
Params const* params_ptr;
|
||||||
|
int n;
|
||||||
|
|
||||||
|
// This function is modified from VisitorRowBroadcast
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
begin_epilogue() {
|
||||||
|
clear(tC_rRow);
|
||||||
|
auto src_v = filter(tC_gRow);
|
||||||
|
auto coord_v = filter(tC_cRow);
|
||||||
|
auto dst_v = filter(tC_rRow);
|
||||||
|
|
||||||
|
if (params_ptr->ptr_row != nullptr) {
|
||||||
|
// In this case we are loading from a row vector and broadcasting
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(src_v); ++i) {
|
||||||
|
bool guard = get<1>(coord_v(i)) < n;
|
||||||
|
cutlass::arch::global_load<VecType, sizeof(VecType)>(
|
||||||
|
dst_v(i), (void const*)&src_v(i), guard);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// In this case we are broadcasting 0
|
||||||
|
VecType filled_vec;
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < VecLength; i++) {
|
||||||
|
reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(src_v); ++i) {
|
||||||
|
if (get<1>(coord_v(i)) < n) {
|
||||||
|
dst_v(i) = filled_vec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ElementAccumulator, int FragmentSize>
|
||||||
|
CUTLASS_DEVICE auto // returns an Array
|
||||||
|
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||||
|
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||||
|
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
|
||||||
|
return rRow_frg(column_idx);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_callbacks(
|
||||||
|
gemm::GemmCoord threadblock_tile_offset,
|
||||||
|
int thread_idx,
|
||||||
|
ProblemShape problem_shape
|
||||||
|
) {
|
||||||
|
Tensor mRow = make_tensor(
|
||||||
|
make_gmem_ptr(params_ptr->ptr_row),
|
||||||
|
problem_shape,
|
||||||
|
params_ptr->dRow);
|
||||||
|
|
||||||
|
// VECTOR, FRAGMENT_COLUMN
|
||||||
|
Tensor tC_gRow = recast<VecType>(
|
||||||
|
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
|
||||||
|
)(_,_,_0{},_0{},_0{},_0{});
|
||||||
|
Tensor tC_rRow = make_tensor_like(tC_gRow);
|
||||||
|
|
||||||
|
// Generate the pred tensor
|
||||||
|
Tensor cRow = make_identity_tensor(mRow.shape());
|
||||||
|
Tensor tC_cRow = outer_partition(
|
||||||
|
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
|
||||||
|
Shape<Int<VecLength>>{},
|
||||||
|
(_0{})
|
||||||
|
);
|
||||||
|
|
||||||
|
return Callbacks<
|
||||||
|
decltype(tC_gRow), decltype(tC_rRow),
|
||||||
|
decltype(tC_cRow), ProblemShape>(
|
||||||
|
cute::move(tC_gRow),
|
||||||
|
cute::move(tC_rRow),
|
||||||
|
cute::move(tC_cRow),
|
||||||
|
problem_shape,
|
||||||
|
params_ptr
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// Column vector broadcast
|
// Column vector broadcast
|
||||||
@@ -217,7 +367,7 @@ template<
|
|||||||
>
|
>
|
||||||
struct VisitorColOrScalarBroadcast {
|
struct VisitorColOrScalarBroadcast {
|
||||||
|
|
||||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||||
// scalar that must be broadcast.
|
// scalar that must be broadcast.
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
Element const* ptr_col = nullptr;
|
Element const* ptr_col = nullptr;
|
||||||
|
|||||||
@@ -50,6 +50,25 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SM75 entry point for the scaled matmul with an activation zero point (azp).
//
// Both scale tensors must be float32. When `azp` holds a tensor the
// ScaledEpilogueBiasAzpToken epilogue is used and receives azp_adj, the azp
// tensor and the optional bias; otherwise ScaledEpilogueBiasAzp is used and
// receives only azp_adj and the optional bias.
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No azp tensor supplied: dispatch to the epilogue that takes azp_adj only.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||||
|
|
||||||
template <template <typename, typename> typename Epilogue,
|
template <template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||||
@@ -87,6 +106,25 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SM80 entry point for the scaled matmul with an activation zero point (azp).
//
// Both scale tensors must be float32. When `azp` holds a tensor the
// ScaledEpilogueBiasAzpToken epilogue is used and receives azp_adj, the azp
// tensor and the optional bias; otherwise ScaledEpilogueBiasAzp is used and
// receives only azp_adj and the optional bias.
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No azp tensor supplied: dispatch to the epilogue that takes azp_adj only.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||||
|
|
||||||
template <template <typename, typename> typename Epilogue,
|
template <template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||||
@@ -139,3 +177,22 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
out, a, b, a_scales, b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SM89 entry point for the scaled matmul with an activation zero point (azp).
//
// Both scale tensors must be float32. When `azp` holds a tensor the
// ScaledEpilogueBiasAzpToken epilogue is used and receives azp_adj, the azp
// tensor and the optional bias; otherwise ScaledEpilogueBiasAzp is used and
// receives only azp_adj and the optional bias.
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No azp tensor supplied: dispatch to the epilogue that takes azp_adj only.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||||
|
|||||||
@@ -73,19 +73,63 @@ struct enable_sm89_to_sm90 : Kernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* This class provides the common ScaleA and ScaleB descriptors for the
|
* This class provides the common load descriptors for the
|
||||||
* ScaledEpilogue and ScaledEpilogueBias classes.
|
* ScaledEpilogue[...] classes
|
||||||
*/
|
*/
|
||||||
template <typename ElementD, typename OutputTileThreadMap>
|
template <typename ElementD, typename OutputTileThreadMap>
|
||||||
struct ScaledEpilogueBase {
|
struct ScaledEpilogueBase {
|
||||||
protected:
|
protected:
|
||||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||||
|
|
||||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
template <typename T>
|
||||||
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
|
using ColOrScalarLoad =
|
||||||
|
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||||
|
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||||
|
|
||||||
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
template <typename T>
|
||||||
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
|
using RowOrScalarLoad =
|
||||||
|
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||||
|
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
||||||
|
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||||
|
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using RowOrZeroLoad =
|
||||||
|
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
|
||||||
|
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
|
// This utility function constructs the arguments for the load descriptors
|
||||||
|
// from a tensor. It can handle both row and column, as well as row/column or
|
||||||
|
// scalar cases.
|
||||||
|
template <typename Descriptor, typename T>
|
||||||
|
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||||
|
using Arguments = typename Descriptor::Arguments;
|
||||||
|
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||||
|
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||||
|
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||||
|
return Arguments{data_ptr, tensor.numel() != 1};
|
||||||
|
} else {
|
||||||
|
// it would technically work but no use case as data_ptr is never nullptr
|
||||||
|
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||||
|
return Arguments{data_ptr};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This overload handles the case where there might not be a tensor, in which
|
||||||
|
// case a nullptr is passed and a constant (0) is used.
|
||||||
|
template <typename Descriptor, typename T>
|
||||||
|
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||||
|
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||||
|
using Arguments = typename Descriptor::Arguments;
|
||||||
|
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||||
|
return Arguments{data_ptr};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -110,8 +154,8 @@ struct ScaledEpilogue
|
|||||||
private:
|
private:
|
||||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::ScaleA;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::ScaleB;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
|
||||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
cutlass::multiplies, float, float,
|
cutlass::multiplies, float, float,
|
||||||
@@ -131,28 +175,32 @@ struct ScaledEpilogue
|
|||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
using ScaleAArgs = typename ScaleA::Arguments;
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
using ScaleBArgs = typename ScaleB::Arguments;
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
|
||||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
return ArgumentType{a_args, evt0_args};
|
||||||
|
|
||||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
|
||||||
|
|
||||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
|
|
||||||
return evt_compute_args;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||||
|
* This bias can also be used in the per-tensor azp case, where the activation
|
||||||
|
* zero point (azp) is used to compute an azp correction term,
|
||||||
|
* which is folded into the bias.
|
||||||
|
*
|
||||||
|
* The bias tensor must be per-output channel.
|
||||||
|
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||||
|
*/
|
||||||
template <typename ElementD, typename OutputTileThreadMap>
|
template <typename ElementD, typename OutputTileThreadMap>
|
||||||
struct ScaledEpilogueBias
|
struct ScaledEpilogueBias
|
||||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||||
private:
|
protected:
|
||||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::ScaleA;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::ScaleB;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
cutlass::multiplies, float, float,
|
cutlass::multiplies, float, float,
|
||||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
@@ -164,30 +212,163 @@ struct ScaledEpilogueBias
|
|||||||
cutlass::multiply_add, ElementD, float,
|
cutlass::multiply_add, ElementD, float,
|
||||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
|
||||||
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||||
EVTCompute0, Bias>;
|
EVTCompute0, Bias>;
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
torch::Tensor const& b_scales,
|
||||||
torch::Tensor const& bias) {
|
torch::Tensor const& bias) {
|
||||||
using ScaleAArgs = typename ScaleA::Arguments;
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
using ScaleBArgs = typename ScaleB::Arguments;
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
using BiasArgs = typename Bias::Arguments;
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
|
|
||||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
return ArgumentType{a_args, evt0_args, bias_args};
|
||||||
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
|
}
|
||||||
|
};
|
||||||
|
|
||||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
/*
|
||||||
|
* This epilogue directly supports per-tensor azp in int32 form.
|
||||||
|
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||||
|
* term, which should already be multiplied with the scalar azp.
|
||||||
|
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||||
|
*
|
||||||
|
* This epilogue also supports bias, which remains per-channel.
|
||||||
|
*/
|
||||||
|
template <typename ElementD, typename OutputTileThreadMap>
|
||||||
|
struct ScaledEpilogueBiasAzp
|
||||||
|
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||||
|
private:
|
||||||
|
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||||
|
using Accum = typename SUPER::Accum;
|
||||||
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||||
|
|
||||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
|
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||||
bias_args};
|
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||||
return evt_compute_args;
|
|
||||||
|
// Compute float(accum - azp_adj), both operands are int32_t
|
||||||
|
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::minus, float, int32_t,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeAzp =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||||
|
|
||||||
|
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::multiplies, float, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeScaleB =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||||
|
EVTComputeAzp>;
|
||||||
|
|
||||||
|
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::multiply_add, ElementD, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using EVTCompute =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||||
|
EVTComputeScaleB, Bias>;
|
||||||
|
|
||||||
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
c10::optional<torch::Tensor> const& bias) {
|
||||||
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
|
auto azp_adj_args =
|
||||||
|
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||||
|
|
||||||
|
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||||
|
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||||
|
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This epilogue supports per-token azp by computing and applying
|
||||||
|
* the correction term using a rank-1 update. If the term were materialized,
|
||||||
|
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||||
|
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||||
|
* point for each row of A.
|
||||||
|
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||||
|
*
|
||||||
|
* This epilogue also supports bias, which remains per-channel.
|
||||||
|
*/
|
||||||
|
template <typename ElementD, typename OutputTileThreadMap>
|
||||||
|
struct ScaledEpilogueBiasAzpToken
|
||||||
|
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||||
|
private:
|
||||||
|
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||||
|
using Accum = typename SUPER::Accum;
|
||||||
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||||
|
|
||||||
|
// Per-token azp term, shape (m,1)
|
||||||
|
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||||
|
|
||||||
|
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||||
|
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||||
|
|
||||||
|
// Compute azp * azp_adj
|
||||||
|
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::multiplies, int32_t, int32_t,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeAzp =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
|
||||||
|
|
||||||
|
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||||
|
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::minus, float, int32_t,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeAcc =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||||
|
|
||||||
|
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::multiplies, float, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeScaleB =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||||
|
EVTComputeAcc>;
|
||||||
|
|
||||||
|
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::multiply_add, ElementD, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using EVTCompute =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||||
|
EVTComputeScaleB, Bias>;
|
||||||
|
|
||||||
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
torch::Tensor const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias) {
|
||||||
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
|
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||||
|
auto azp_adj_args =
|
||||||
|
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||||
|
|
||||||
|
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||||
|
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||||
|
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||||
|
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -58,21 +58,63 @@ struct enable_sm90_or_later : Kernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* This class provides the common ScaleA and ScaleB descriptors for the
|
* This class provides the common load descriptors for the
|
||||||
* ScaledEpilogue and ScaledEpilogueBias classes.
|
* ScaledEpilogue[...] classes
|
||||||
*/
|
*/
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||||
struct ScaledEpilogueBase {
|
struct ScaledEpilogueBase {
|
||||||
protected:
|
protected:
|
||||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||||
|
|
||||||
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
template <typename T>
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||||
|
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||||
|
|
||||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
template <typename T>
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||||
|
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
|
// Don't want to support nullptr by default
|
||||||
|
template <typename T, bool EnableNullPtr = false>
|
||||||
|
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||||
|
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||||
|
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||||
|
|
||||||
|
// Don't want to support nullptr by default
|
||||||
|
template <typename T, bool EnableNullPtr = false>
|
||||||
|
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||||
|
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||||
|
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||||
|
|
||||||
|
// This utility function constructs the arguments for the load descriptors
|
||||||
|
// from a tensor. It can handle both row and column, as well as row/column or
|
||||||
|
// scalar cases.
|
||||||
|
template <typename Descriptor, typename T>
|
||||||
|
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||||
|
using Arguments = typename Descriptor::Arguments;
|
||||||
|
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||||
|
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||||
|
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||||
|
return Arguments{data_ptr, tensor.numel() != 1};
|
||||||
|
} else {
|
||||||
|
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
|
||||||
|
!std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||||
|
return Arguments{data_ptr};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This overload handles the case where there might not be a tensor, in which
|
||||||
|
// case a nullptr is passed and a constant (0) is used.
|
||||||
|
template <typename Descriptor, typename T>
|
||||||
|
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||||
|
using Arguments = typename Descriptor::Arguments;
|
||||||
|
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||||
|
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||||
|
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||||
|
return Arguments{data_ptr};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -97,8 +139,8 @@ struct ScaledEpilogue
|
|||||||
private:
|
private:
|
||||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::ScaleA;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::ScaleB;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
|
||||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
cutlass::multiplies, float, float,
|
cutlass::multiplies, float, float,
|
||||||
@@ -118,24 +160,32 @@ struct ScaledEpilogue
|
|||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
using ScaleA_Args = typename ScaleA::Arguments;
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
using ScaleB_Args = typename ScaleB::Arguments;
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
|
||||||
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||||
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
return ArgumentType{a_args, evt0_args};
|
||||||
|
|
||||||
return ArgumentType{a_args, {b_args}};
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||||
|
* This bias can also be used in the per-tensor azp case, where the activation
|
||||||
|
* zero point (azp) is used to compute an azp correction term,
|
||||||
|
* which is folded into the bias.
|
||||||
|
*
|
||||||
|
* The bias tensor must be per-output channel.
|
||||||
|
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||||
|
*/
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||||
struct ScaledEpilogueBias
|
struct ScaledEpilogueBias
|
||||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||||
private:
|
private:
|
||||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::ScaleA;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::ScaleB;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||||
|
|
||||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
cutlass::multiplies, float, float,
|
cutlass::multiplies, float, float,
|
||||||
@@ -148,27 +198,160 @@ struct ScaledEpilogueBias
|
|||||||
cutlass::multiply_add, ElementD, float,
|
cutlass::multiply_add, ElementD, float,
|
||||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
|
|
||||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
using EVTCompute =
|
using EVTCompute =
|
||||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||||
|
|
||||||
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& bias) {
|
||||||
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
|
|
||||||
|
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||||
|
return ArgumentType{a_args, evt0_args, bias_args};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This epilogue directly supports per-tensor azp in int32 form.
|
||||||
|
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||||
|
* term, which should already be multiplied with the scalar azp.
|
||||||
|
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||||
|
*
|
||||||
|
* This epilogue also supports bias, which remains per-channel.
|
||||||
|
*/
|
||||||
|
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||||
|
struct ScaledEpilogueBiasAzp
|
||||||
|
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||||
|
private:
|
||||||
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||||
|
using Accum = typename SUPER::Accum;
|
||||||
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||||
|
|
||||||
|
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||||
|
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||||
|
|
||||||
|
// Compute float(accum - azp_adj), both operands are int32_t
|
||||||
|
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::minus, float, int32_t,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeAzp =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||||
|
|
||||||
|
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiplies, float, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeScaleB =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||||
|
|
||||||
|
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiply_add, ElementD, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using EVTCompute =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||||
|
EVTComputeScaleB, Bias>;
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
torch::Tensor const& b_scales,
|
||||||
torch::Tensor const& bias) {
|
torch::Tensor const& azp_adj,
|
||||||
using ScaleA_Args = typename ScaleA::Arguments;
|
c10::optional<torch::Tensor> const& bias) {
|
||||||
using ScaleB_Args = typename ScaleB::Arguments;
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
using Bias_Args = typename Bias::Arguments;
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
|
auto azp_adj_args =
|
||||||
|
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||||
|
|
||||||
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||||
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||||
Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};
|
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return ArgumentType{a_args, {b_args}, bias_args};
|
/*
|
||||||
|
* This epilogue supports per-token azp by computing and applying
|
||||||
|
* the correction term using a rank-1 update. If the term were materialized,
|
||||||
|
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||||
|
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||||
|
* point for each row of A.
|
||||||
|
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||||
|
*
|
||||||
|
* This epilogue also supports bias, which remains per-channel.
|
||||||
|
*/
|
||||||
|
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||||
|
struct ScaledEpilogueBiasAzpToken
|
||||||
|
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||||
|
private:
|
||||||
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||||
|
using Accum = typename SUPER::Accum;
|
||||||
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||||
|
|
||||||
|
// Per-token azp term, shape (m,1)
|
||||||
|
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||||
|
|
||||||
|
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||||
|
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||||
|
|
||||||
|
// Compute azp * azp_adj
|
||||||
|
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiplies, int32_t, int32_t,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeAzp =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
|
||||||
|
|
||||||
|
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||||
|
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::minus, float, int32_t,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeAcc =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||||
|
|
||||||
|
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiplies, float, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTComputeScaleB =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||||
|
|
||||||
|
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiply_add, ElementD, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using EVTCompute =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||||
|
EVTComputeScaleB, Bias>;
|
||||||
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
torch::Tensor const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias) {
|
||||||
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
|
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||||
|
auto azp_adj_args =
|
||||||
|
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||||
|
|
||||||
|
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||||
|
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||||
|
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||||
|
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -546,4 +729,23 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
c10::optional<torch::Tensor> const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias) {
|
||||||
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||||
|
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||||
|
|
||||||
|
if (azp) {
|
||||||
|
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
|
||||||
|
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
|
||||||
|
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -29,6 +29,40 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
c10::optional<torch::Tensor> const& bias);
|
c10::optional<torch::Tensor> const& bias);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
c10::optional<torch::Tensor> const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
c10::optional<torch::Tensor> const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
c10::optional<torch::Tensor> const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias);
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||||
|
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
c10::optional<torch::Tensor> const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias);
|
||||||
|
#endif
|
||||||
|
|
||||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||||
// CUTLASS FP8 kernels need at least
|
// CUTLASS FP8 kernels need at least
|
||||||
// CUDA 12.0 on SM90 systems (Hopper)
|
// CUDA 12.0 on SM90 systems (Hopper)
|
||||||
@@ -45,18 +79,20 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
int32_t get_sm_version_num() {
|
||||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
int32_t major_capability, minor_capability;
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
c10::optional<torch::Tensor> const& bias) {
|
|
||||||
int32_t major_capability;
|
|
||||||
int32_t minor_capability;
|
|
||||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||||
0);
|
0);
|
||||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||||
0);
|
0);
|
||||||
int32_t version_num = major_capability * 10 + minor_capability;
|
int32_t version_num = major_capability * 10 + minor_capability;
|
||||||
|
return version_num;
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
c10::optional<torch::Tensor> const& bias) {
|
||||||
// Checks for conformality
|
// Checks for conformality
|
||||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||||
@@ -77,7 +113,7 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
if (version_num >= 90) {
|
if (version_num >= 90) {
|
||||||
// Hopper
|
// Hopper
|
||||||
|
|
||||||
@@ -99,3 +135,64 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& azp_adj,
|
||||||
|
c10::optional<torch::Tensor> const& azp,
|
||||||
|
c10::optional<torch::Tensor> const& bias) {
|
||||||
|
// Checks for conformality
|
||||||
|
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||||
|
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||||
|
b.size(1) == c.size(1));
|
||||||
|
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||||
|
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||||
|
|
||||||
|
// Check for strides and alignment
|
||||||
|
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||||
|
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||||
|
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||||
|
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||||
|
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
|
|
||||||
|
// bias, azp, azp_adj are all 1d
|
||||||
|
// bias and azp_adj have n elements, azp has m elements
|
||||||
|
if (bias) {
|
||||||
|
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
|
||||||
|
}
|
||||||
|
if (azp) {
|
||||||
|
TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
|
||||||
|
}
|
||||||
|
TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
|
||||||
|
|
||||||
|
// azp & bias types
|
||||||
|
TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
|
||||||
|
TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
|
||||||
|
TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
|
||||||
|
"currently bias dtype must match output dtype ", c.dtype());
|
||||||
|
|
||||||
|
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
if (version_num >= 90) {
|
||||||
|
// Hopper
|
||||||
|
|
||||||
|
// Guard against compilation issues for sm90 kernels
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||||
|
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
#else
|
||||||
|
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
#endif
|
||||||
|
} else if (version_num == 89) {
|
||||||
|
// Ada Lovelace
|
||||||
|
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
} else if (version_num >= 80) {
|
||||||
|
// Ampere
|
||||||
|
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
} else {
|
||||||
|
// Turing
|
||||||
|
TORCH_CHECK(version_num >= 75);
|
||||||
|
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,7 +7,25 @@
|
|||||||
#include "cuda_compat.h"
|
#include "cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
#include "../../reduction_utils.cuh"
|
#ifndef USE_ROCM
|
||||||
|
#include <cub/util_type.cuh>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
#else
|
||||||
|
#include <hipcub/util_type.hpp>
|
||||||
|
#include <hipcub/hipcub.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||||
|
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
|
||||||
|
std::numeric_limits<FP8_TYPE>::max();
|
||||||
|
#else
|
||||||
|
#include "amd/hip_float8.h"
|
||||||
|
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||||
|
// Using the default max value from pytorch (240.0) will cause accuracy
|
||||||
|
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
||||||
|
constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
@@ -21,11 +39,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
|||||||
return old;
|
return old;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
|
|
||||||
|
|
||||||
template <bool is_scale_inverted>
|
template <bool is_scale_inverted>
|
||||||
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||||
float const val, float const scale) {
|
float const scale) {
|
||||||
float x = 0.0f;
|
float x = 0.0f;
|
||||||
if constexpr (is_scale_inverted) {
|
if constexpr (is_scale_inverted) {
|
||||||
x = val * scale;
|
x = val * scale;
|
||||||
@@ -34,7 +50,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||||
|
#ifndef USE_ROCM
|
||||||
return static_cast<c10::Float8_e4m3fn>(r);
|
return static_cast<c10::Float8_e4m3fn>(r);
|
||||||
|
#else
|
||||||
|
// Use hardware cvt instruction for fp8 on rocm
|
||||||
|
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
|
||||||
|
c10::Float8_e4m3fnuz::from_bits());
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute the absolute maximum m of the input tensor and store
|
// Compute the absolute maximum m of the input tensor and store
|
||||||
@@ -74,8 +96,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
|||||||
// Finally, since cache[0] contains the maximum for this thread block,
|
// Finally, since cache[0] contains the maximum for this thread block,
|
||||||
// atomically write the max to the target location
|
// atomically write the max to the target location
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
atomicMaxFloat(scale,
|
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
|
||||||
cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,10 +109,10 @@ struct __align__(8) vec4_t {
|
|||||||
};
|
};
|
||||||
|
|
||||||
typedef struct __align__(4) {
|
typedef struct __align__(4) {
|
||||||
c10::Float8_e4m3fn x;
|
FP8_TYPE x;
|
||||||
c10::Float8_e4m3fn y;
|
FP8_TYPE y;
|
||||||
c10::Float8_e4m3fn z;
|
FP8_TYPE z;
|
||||||
c10::Float8_e4m3fn w;
|
FP8_TYPE w;
|
||||||
}
|
}
|
||||||
float8x4_t;
|
float8x4_t;
|
||||||
|
|
||||||
@@ -124,7 +145,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, bool is_scale_inverted>
|
template <typename scalar_t, bool is_scale_inverted>
|
||||||
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
|
||||||
scalar_t const* __restrict__ input,
|
scalar_t const* __restrict__ input,
|
||||||
float const scale,
|
float const scale,
|
||||||
int64_t const num_elems,
|
int64_t const num_elems,
|
||||||
@@ -160,7 +181,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
|
||||||
const scalar_t* __restrict__ input,
|
const scalar_t* __restrict__ input,
|
||||||
const float* __restrict__ scale,
|
const float* __restrict__ scale,
|
||||||
int64_t num_elems) {
|
int64_t num_elems) {
|
||||||
@@ -175,7 +196,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
|||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||||
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
|
FP8_TYPE* __restrict__ out, float* __restrict__ scale,
|
||||||
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
||||||
const int hidden_size) {
|
const int hidden_size) {
|
||||||
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||||
@@ -184,7 +205,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
|||||||
int const token_idx = blockIdx.x;
|
int const token_idx = blockIdx.x;
|
||||||
|
|
||||||
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
|
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
|
||||||
c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
|
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
|
||||||
|
|
||||||
// For vectorization, token_input and token_output pointers need to be
|
// For vectorization, token_input and token_output pointers need to be
|
||||||
// aligned at 8-byte and 4-byte addresses respectively.
|
// aligned at 8-byte and 4-byte addresses respectively.
|
||||||
@@ -200,7 +221,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
|
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage reduceStorage;
|
||||||
|
float const block_absmax_val_maybe =
|
||||||
|
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||||
__shared__ float token_scale;
|
__shared__ float token_scale;
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
if (scale_ub) {
|
if (scale_ub) {
|
||||||
@@ -241,7 +265,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
|||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
|
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
|
||||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
|
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||||
scale.data_ptr<float>(), num_elems);
|
scale.data_ptr<float>(), num_elems);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -261,7 +285,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
|||||||
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
|
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
|
||||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
|
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||||
scale.data_ptr<float>(), num_elems);
|
scale.data_ptr<float>(), num_elems);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -284,7 +308,7 @@ void dynamic_per_token_scaled_fp8_quant(
|
|||||||
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
|
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
|
||||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
|
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
|
||||||
<<<grid, block, 0, stream>>>(
|
<<<grid, block, 0, stream>>>(
|
||||||
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
|
out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
|
||||||
input.data_ptr<scalar_t>(),
|
input.data_ptr<scalar_t>(),
|
||||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||||
hidden_size);
|
hidden_size);
|
||||||
|
|||||||
531
csrc/quantization/gguf/dequantize.cuh
Normal file
531
csrc/quantization/gguf/dequantize.cuh
Normal file
@@ -0,0 +1,531 @@
|
|||||||
|
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/convert.cu
|
||||||
|
// Dequant functions
|
||||||
|
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
|
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||||
|
|
||||||
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
|
const int vui = x[ib].qs[iqs];
|
||||||
|
|
||||||
|
v.x = __int2half_rn(vui & 0xF);
|
||||||
|
v.y = __int2half_rn(vui >> 4);
|
||||||
|
|
||||||
|
v = __hsub2(v, __floats2half2_rn(8.0f, 8.0f));
|
||||||
|
v = __hmul2(v, {d, d});
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||||
|
|
||||||
|
const dfloat d = __low2half(x[ib].dm);
|
||||||
|
const dfloat m = __high2half(x[ib].dm);
|
||||||
|
|
||||||
|
const int vui = x[ib].qs[iqs];
|
||||||
|
|
||||||
|
v.x = __int2half_rn(vui & 0xF);
|
||||||
|
v.y = __int2half_rn(vui >> 4);
|
||||||
|
|
||||||
|
v = __hmul2(v, {d, d});
|
||||||
|
v = __hadd2(v, {m, m});
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
|
const block_q5_0 * x = (const block_q5_0 *) vx;
|
||||||
|
|
||||||
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
|
uint32_t qh;
|
||||||
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||||
|
|
||||||
|
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||||
|
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||||
|
|
||||||
|
v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||||
|
v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
|
v = __hsub2(v, __floats2half2_rn(16.0f, 16.0f));
|
||||||
|
v = __hmul2(v, {d, d});
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
|
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||||
|
|
||||||
|
const dfloat d = __low2half(x[ib].dm);
|
||||||
|
const dfloat m = __high2half(x[ib].dm);
|
||||||
|
|
||||||
|
uint32_t qh;
|
||||||
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||||
|
|
||||||
|
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||||
|
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||||
|
|
||||||
|
v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||||
|
v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
|
v = __hmul2(v, {d, d});
|
||||||
|
v = __hadd2(v, {m, m});
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
|
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||||
|
|
||||||
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
|
v.x = __int2half_rn(x[ib].qs[iqs + 0]);
|
||||||
|
v.y = __int2half_rn(x[ib].qs[iqs + 1]);
|
||||||
|
|
||||||
|
v = __hmul2(v, {d, d});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||||
|
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
|
||||||
|
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ib = i/qk; // block index
|
||||||
|
const int iqs = (i%qk)/qr; // quant index
|
||||||
|
const int iybs = i - i%qk; // y block start index
|
||||||
|
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
|
// dequantize
|
||||||
|
dfloat2 v;
|
||||||
|
dequantize_kernel(vx, ib, iqs, v);
|
||||||
|
|
||||||
|
y[iybs + iqs + 0] = v.x;
|
||||||
|
y[iybs + iqs + y_offset] = v.y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
|
const int i = blockIdx.x;
|
||||||
|
const block_q2_K * x = (const block_q2_K *) vx;
|
||||||
|
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int n = tid/32;
|
||||||
|
const int l = tid - 32*n;
|
||||||
|
const int is = 8*n + l/16;
|
||||||
|
|
||||||
|
const uint8_t q = x[i].qs[32*n + l];
|
||||||
|
dst_t * y = yy + i*QK_K + 128*n;
|
||||||
|
|
||||||
|
half dall = __low2half(x[i].dm);
|
||||||
|
half dmin = __high2half(x[i].dm);
|
||||||
|
y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4)));
|
||||||
|
y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4)));
|
||||||
|
y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4)));
|
||||||
|
y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
|
const int i = blockIdx.x;
|
||||||
|
const block_q3_K * x = (const block_q3_K *) vx;
|
||||||
|
|
||||||
|
const int r = threadIdx.x/4;
|
||||||
|
const int tid = r/2;
|
||||||
|
const int is0 = r%2;
|
||||||
|
const int l0 = 16*is0 + 4*(threadIdx.x%4);
|
||||||
|
const int n = tid / 4;
|
||||||
|
const int j = tid - 4*n;
|
||||||
|
|
||||||
|
uint8_t m = 1 << (4*n + j);
|
||||||
|
int is = 8*n + 2*j + is0;
|
||||||
|
int shift = 2*j;
|
||||||
|
|
||||||
|
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
|
||||||
|
is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
|
||||||
|
is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
|
||||||
|
(x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
|
||||||
|
half d_all = x[i].d;
|
||||||
|
half dl = __hmul(d_all, __int2half_rn(us - 32));
|
||||||
|
|
||||||
|
dst_t * y = yy + i*QK_K + 128*n + 32*j;
|
||||||
|
const uint8_t * q = x[i].qs + 32*n;
|
||||||
|
const uint8_t * hm = x[i].hmask;
|
||||||
|
|
||||||
|
for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
|
||||||
|
if (j < 4) {
|
||||||
|
d = q[j] & 63; m = q[j + 4] & 63;
|
||||||
|
} else {
|
||||||
|
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
||||||
|
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
const block_q4_K * x = (const block_q4_K *) vx;
|
||||||
|
|
||||||
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
|
// assume 32 threads
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int il = tid/8;
|
||||||
|
const int ir = tid%8;
|
||||||
|
const int is = 2*il;
|
||||||
|
const int n = 4;
|
||||||
|
|
||||||
|
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||||
|
|
||||||
|
const half dall = __low2half(x[i].dm);
|
||||||
|
const half dmin = __high2half(x[i].dm);
|
||||||
|
|
||||||
|
const uint8_t * q = x[i].qs + 32*il + n*ir;
|
||||||
|
|
||||||
|
uint8_t sc, m;
|
||||||
|
get_scale_min_k4(is + 0, x[i].scales, sc, m);
|
||||||
|
const half d1 = __hmul(dall, __int2half_rn(sc));
|
||||||
|
const half m1 = __hmul(dmin, __int2half_rn(m));
|
||||||
|
get_scale_min_k4(is + 1, x[i].scales, sc, m);
|
||||||
|
const half d2 = __hmul(dall, __int2half_rn(sc));
|
||||||
|
const half m2 = __hmul(dmin, __int2half_rn(m));
|
||||||
|
for (int l = 0; l < n; ++l) {
|
||||||
|
y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
|
||||||
|
y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dequantizes one block_q5_K per CUDA block into yy.
// Written for 64 threads per block: each thread emits two adjacent values
// from the low nibbles and two (32 elements further on) from the high nibbles.
template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_q5_K * blk = (const block_q5_K *) vx + block_idx;

    const int lane      = threadIdx.x;
    const int group     = lane/16;   // 0..3 for a 64-thread block
    const int pair      = lane%16;   // 0..15
    const int scale_idx = 2*group;   // first of the two scale/min entries used by this thread

    dst_t * out = yy + block_idx*QK_K + 64*group + 2*pair;

    // dm packs the super-block scale (low half) and min (high half).
    const half super_d   = __low2half(blk->dm);
    const half super_min = __high2half(blk->dm);

    const uint8_t * low_bits  = blk->qs + 32*group + 2*pair;
    const uint8_t * high_bits = blk->qh + 2*pair;

    uint8_t sc, m;

    get_scale_min_k4(scale_idx + 0, blk->scales, sc, m);
    const half d_lo = __hmul(super_d,   __int2half_rn(sc));
    const half m_lo = __hmul(super_min, __int2half_rn(m));

    get_scale_min_k4(scale_idx + 1, blk->scales, sc, m);
    const half d_hi = __hmul(super_d,   __int2half_rn(sc));
    const half m_hi = __hmul(super_min, __int2half_rn(m));

    // Bit of qh that carries the fifth bit for the low-nibble / high-nibble value.
    const uint8_t lo_mask = 1 << (2*group);
    const uint8_t hi_mask = lo_mask << 1;

    for (int k = 0; k < 2; ++k) {
        const int q_lo = (low_bits[k] & 0xF) + ((high_bits[k] & lo_mask) ? 16 : 0);
        const int q_hi = (low_bits[k] >> 4)  + ((high_bits[k] & hi_mask) ? 16 : 0);

        out[k]      = __hsub(__hmul(d_lo, __int2half_rn(q_lo)), m_lo);
        out[k + 32] = __hsub(__hmul(d_hi, __int2half_rn(q_hi)), m_hi);
    }
}
|
||||||
|
|
||||||
|
// Dequantizes one block_q6_K per CUDA block into yy.
// Written for 64 threads per block: each thread produces four values,
// spaced 32 elements apart, inside its 128-element half of the block.
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_q6_K * blk = (const block_q6_K *) vx + block_idx;

    const int lane      = threadIdx.x;
    const int half_idx  = lane/32;               // 0 or 1
    const int pos       = lane - 32*half_idx;    // 0..31
    const int scale_idx = 8*half_idx + pos/16;

    dst_t * out = yy + block_idx*QK_K + 128*half_idx + pos;

    const half d = blk->d;

    const uint8_t * low_bits = blk->ql + 64*half_idx + pos;
    const uint8_t   hi_byte  = blk->qh[32*half_idx + pos];
    const int8_t  * scales   = blk->scales + scale_idx;

    for (int k = 0; k < 4; ++k) {
        // k = 0,1 use the low nibble of ql[0] / ql[32]; k = 2,3 use the high nibble.
        const uint8_t packed = low_bits[32*(k & 1)];
        const int     low4   = (k < 2) ? (packed & 0xF) : (packed >> 4);
        const int     high2  = (hi_byte >> (2*k)) & 3;
        const int     q      = (int8_t)(low4 | (high2 << 4)) - 32;

        out[32*k] = __hmul(d, __int2half_rn(scales[2*k] * q));
    }
}
|
||||||
|
|
||||||
|
// Dequantizes one block_iq2_xxs per CUDA block into yy.
// Thread layout: ib = tid%8 picks a 32-element group, il = tid/8 picks the
// 8-element slice inside it; every thread writes 8 consecutive outputs.
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_iq2_xxs * blk = (const block_iq2_xxs *) vx + block_idx;

    const int lane  = threadIdx.x;
    const int slice = lane/8;   // 0..3 for a 32-thread block
    const int group = lane%8;   // 0..7

    dst_t * out = yy + block_idx*QK_K + 32*group + 8*slice;

    // Four 16-bit words per group: the first two hold the grid indices (one
    // byte per slice), the last two form a 32-bit word of signs and scale.
    const uint16_t * words    = blk->qs + 4*group;
    const uint8_t  * grid_idx = (const uint8_t *)words;
    const uint8_t  * grid     = (const uint8_t *)(iq2xxs_grid + grid_idx[slice]);

    const uint32_t aux32 = words[2] | (words[3] << 16);

    // The top 4 bits carry the group scale; each slice owns 7 bits of sign index.
    const float   d     = __half2float(blk->d) * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*slice) & 127];

    for (int j = 0; j < 8; ++j) {
        const float sign = (signs & kmask_iq2xs[j]) ? -1.f : 1.f;
        out[j] = __float2half(d * grid[j] * sign);
    }
}
|
||||||
|
|
||||||
|
// Dequantizes one block_iq2_xs per CUDA block into yy.
// Thread layout: ib = tid%8 picks a 32-element group, il = tid/8 picks the
// 8-element slice inside it; every thread writes 8 consecutive outputs.
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_iq2_xs * blk = (const block_iq2_xs *) vx + block_idx;

    const int lane  = threadIdx.x;
    const int slice = lane/8;   // 0..3 for a 32-thread block
    const int group = lane%8;   // 0..7

    dst_t * out = yy + block_idx*QK_K + 32*group + 8*slice;

    // Each 16-bit entry: low 9 bits select the grid row, high 7 bits the sign pattern.
    const uint16_t * entries = blk->qs + 4*group;
    const uint16_t   entry   = entries[slice];
    const uint8_t  * grid    = (const uint8_t *)(iq2xs_grid + (entry & 511));

    // One scale byte per group; slices 0-1 use its low nibble, slices 2-3 the high one.
    const int   scale4 = (blk->scales[group] >> 4*(slice/2)) & 0xf;
    const float d      = __half2float(blk->d) * (0.5f + scale4) * 0.25f;

    const uint8_t signs = ksigns_iq2xs[entry >> 9];

    for (int j = 0; j < 8; ++j) {
        const float sign = (signs & kmask_iq2xs[j]) ? -1.f : 1.f;
        out[j] = __float2half(d * grid[j] * sign);
    }
}
|
||||||
|
|
||||||
|
// Dequantizes one block_iq2_s per CUDA block into yy.
// Thread layout: ib = tid%8 picks a 32-element group, il = tid/8 picks the
// 8-element slice inside it; every thread writes 8 consecutive outputs.
template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_iq2_s * blk = (const block_iq2_s *) vx + block_idx;

    const int lane  = threadIdx.x;
    const int slice = lane/8;   // 0..3 for a 32-thread block
    const int group = lane%8;   // 0..7
    const int entry = 4*group + slice;

    dst_t * out = yy + block_idx*QK_K + 32*group + 8*slice;

    // 10-bit grid index: 8 bits from qs plus 2 bits taken from this group's qh byte.
    const int grid_low  = blk->qs[entry];
    const int grid_high = (blk->qh[group] << (8-2*slice)) & 0x300;
    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (grid_low | grid_high));

    // One scale byte per group; slices 0-1 use its low nibble, slices 2-3 the high one.
    const int   scale4 = (blk->scales[group] >> 4*(slice/2)) & 0xf;
    const float d      = __half2float(blk->d) * (0.5f + scale4) * 0.25f;

    // Sign bytes are read from qs at an offset of QK_K/8 past the grid indices.
    const uint8_t signs = blk->qs[QK_K/8 + entry];

    for (int j = 0; j < 8; ++j) {
        const float sign = (signs & kmask_iq2xs[j]) ? -1.f : 1.f;
        out[j] = __float2half(d * grid[j] * sign);
    }
}
|
||||||
|
|
||||||
|
// Dequantizes one block_iq3_xxs per CUDA block into yy.
// Thread layout: ib = tid%8 picks a 32-element group, il = tid/8 picks the
// 8-element slice inside it; each slice is built from two 4-value grid rows.
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_iq3_xxs * blk = (const block_iq3_xxs *) vx + block_idx;

    const int lane  = threadIdx.x;
    const int slice = lane/8;   // 0..3 for a 32-thread block
    const int group = lane%8;   // 0..7

    dst_t * out = yy + block_idx*QK_K + 32*group + 8*slice;

    const uint8_t  * grid_idx = blk->qs + 8*group;
    // Two 16-bit words per group, read from qs at an offset of QK_K/4 bytes.
    const uint16_t * gas      = (const uint16_t *)(blk->qs + QK_K/4) + 2*group;

    const uint8_t * grid_a = (const uint8_t *)(iq3xxs_grid + grid_idx[2*slice+0]);
    const uint8_t * grid_b = (const uint8_t *)(iq3xxs_grid + grid_idx[2*slice+1]);

    const uint32_t aux32 = gas[0] | (gas[1] << 16);

    // The top 4 bits carry the group scale; each slice owns 7 bits of sign index.
    const float   d     = __half2float(blk->d) * (0.5f + (aux32 >> 28)) * 0.5f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*slice) & 127];

    for (int j = 0; j < 4; ++j) {
        const float sign_a = (signs & kmask_iq2xs[j+0]) ? -1.f : 1.f;
        const float sign_b = (signs & kmask_iq2xs[j+4]) ? -1.f : 1.f;
        out[j+0] = __float2half(d * grid_a[j] * sign_a);
        out[j+4] = __float2half(d * grid_b[j] * sign_b);
    }
}
|
||||||
|
|
||||||
|
// Dequantizes one block_iq3_s per CUDA block into yy.
// Thread layout: ib = tid%8 picks a 32-element group, il = tid/8 picks the
// 8-element slice inside it; each slice is built from two 4-value grid rows.
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_iq3_s * blk = (const block_iq3_s *) vx + block_idx;

    const int lane  = threadIdx.x;
    const int slice = lane/8;   // 0..3 for a 32-thread block
    const int group = lane%8;   // 0..7

    dst_t * out = yy + block_idx*QK_K + 32*group + 8*slice;

    const uint8_t * grid_idx = blk->qs + 8*group;
    const uint8_t   hi_bits  = blk->qh[group];

    // 9-bit grid indices: 8 bits from qs plus one bit pulled out of the group's qh byte.
    const int idx_a = grid_idx[2*slice+0] | ((hi_bits << (8-2*slice)) & 256);
    const int idx_b = grid_idx[2*slice+1] | ((hi_bits << (7-2*slice)) & 256);
    const uint8_t * grid_a = (const uint8_t *)(iq3xs_grid + idx_a);
    const uint8_t * grid_b = (const uint8_t *)(iq3xs_grid + idx_b);

    // Two groups share one scale byte: even groups use the low nibble, odd the high one.
    const int   scale4 = (blk->scales[group/2] >> 4*(group%2)) & 0xf;
    const float d      = __half2float(blk->d) * (0.5f + scale4) * 0.5f;

    const uint8_t signs = blk->signs[4*group + slice];

    for (int j = 0; j < 4; ++j) {
        const float sign_a = (signs & kmask_iq2xs[j+0]) ? -1.f : 1.f;
        const float sign_b = (signs & kmask_iq2xs[j+4]) ? -1.f : 1.f;
        out[j+0] = __float2half(d * grid_a[j] * sign_a);
        out[j+4] = __float2half(d * grid_b[j] * sign_b);
    }
}
|
||||||
|
|
||||||
|
// Dequantizes one block_iq1_s per CUDA block into yy.
// Thread layout: ib = tid%8 picks a 32-element group, il = tid/8 picks the
// 8-element slice inside it; every thread writes 8 consecutive outputs.
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_iq1_s * blk = (const block_iq1_s *) vx + block_idx;

    const int lane  = threadIdx.x;
    const int slice = lane/8;   // 0..3 for a 32-thread block
    const int group = lane%8;   // 0..7
    const int entry = 4*group + slice;

    dst_t * out = yy + block_idx*QK_K + 32*group + 8*slice;

    // Two entries share one scales byte (low nibble for even entries, high for odd).
    // Within that nibble, bit 3 extends the grid index and bits 0-2 give the scale.
    uint8_t nibble = blk->scales[entry/2] >> 4*(entry%2);

    const int grid_row = blk->qs[entry] | ((nibble & 8) << 5);
    const int8_t * grid = (const int8_t *)(iq1s_grid + grid_row);

    const float d = __half2float(blk->d) * (2*(nibble & 7) + 1);

    for (int j = 0; j < 8; ++j) {
        out[j] = __float2half(d * grid[j]);
    }
}
|
||||||
|
|
||||||
|
// Dequantizes QK_K/QK4_NL consecutive block_iq4_nl entries per CUDA block into yy.
// Thread layout: ib = tid%8 picks the sub-block, il = tid/8 picks four packed
// bytes inside it; each byte yields one value at j and one at j+16.
template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;

    const int lane  = threadIdx.x;
    const int slice = lane/8;   // 0..3 for a 32-thread block
    const int sub   = lane%8;   // 0..7

    const block_iq4_nl * blk = (const block_iq4_nl *) vx + block_idx*(QK_K/QK4_NL) + sub;

    dst_t * out = yy + block_idx*QK_K + 32*sub + 4*slice;

    const uint8_t * packed = blk->qs + 4*slice;
    const float     d      = __half2float(blk->d);

    for (int j = 0; j < 4; ++j) {
        const uint8_t byte = packed[j];
        out[j+ 0] = __float2half(d * kvalues_iq4nl[byte & 0xf]);
        out[j+16] = __float2half(d * kvalues_iq4nl[byte >> 4]);
    }
}
|
||||||
|
|
||||||
|
// Dequantizes one block_iq4_xs per CUDA block into yy.
// Thread layout: ib = tid%8 picks a 32-element group, il = tid/8 picks four
// packed bytes inside it; each byte yields one value at j and one at j+16.
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int block_idx = blockIdx.x;
    const block_iq4_xs * blk = (const block_iq4_xs *)vx + block_idx;

    const int lane  = threadIdx.x;
    const int slice = lane/8;   // 0..3 for a 32-thread block
    const int group = lane%8;   // 0..7

    dst_t * out = yy + block_idx*QK_K + 32*group + 4*slice;

    const uint8_t * packed = blk->qs + 16*group + 4*slice;

    // 6-bit group scale: 4 low bits from scales_l, 2 high bits from scales_h, biased by 32.
    const int scale_low  = (blk->scales_l[group/2] >> 4*(group%2)) & 0xf;
    const int scale_high = ((blk->scales_h >> 2*group) & 3) << 4;
    const float d = __half2float(blk->d) * ((scale_low | scale_high) - 32);

    for (int j = 0; j < 4; ++j) {
        const uint8_t byte = packed[j];
        out[j+ 0] = __float2half(d * kvalues_iq4nl[byte & 0xf]);
        out[j+16] = __float2half(d * kvalues_iq4nl[byte >> 4]);
    }
}
|
||||||
|
|
||||||
|
// Launches the generic dequantize_block kernel on `stream` for k elements.
// The grid is sized so that each block of CUDA_DEQUANTIZE_BLOCK_SIZE threads
// accounts for 2*CUDA_DEQUANTIZE_BLOCK_SIZE elements (ceil-divided).
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
    const int elems_per_block = 2*CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int grid_size       = (k + elems_per_block - 1) / elems_per_block;

    dequantize_block<qk, qr, dequantize_kernel>
        <<<grid_size, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_q2_K on `stream`: k / QK_K blocks of 64 threads.
template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 64;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_q2_K<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_q3_K on `stream`: k / QK_K blocks of 64 threads.
template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 64;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_q3_K<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_q4_K on `stream`: k / QK_K blocks of 32 threads.
template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_q4_K<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_q5_K on `stream`: k / QK_K blocks of 64 threads.
template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 64;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_q5_K<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_q6_K on `stream`: k / QK_K blocks of 64 threads.
template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 64;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_q6_K<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_iq2_xxs on `stream`: k / QK_K blocks of 32 threads.
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_iq2_xxs<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_iq2_xs on `stream`: k / QK_K blocks of 32 threads.
template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_iq2_xs<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_iq2_s on `stream`: k / QK_K blocks of 32 threads.
template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_iq2_s<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_iq3_xxs on `stream`: k / QK_K blocks of 32 threads.
template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_iq3_xxs<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_iq3_s on `stream`: k / QK_K blocks of 32 threads.
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_iq3_s<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_iq1_s on `stream`: k / QK_K blocks of 32 threads.
template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = k / QK_K;   // integer division
    dequantize_block_iq1_s<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_iq4_nl on `stream` with 32 threads per block.
// Unlike the K-quant wrappers, the grid size is k / QK_K rounded up.
template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = (k + QK_K - 1) / QK_K;   // ceil-div
    dequantize_block_iq4_nl<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Launches dequantize_block_iq4_xs on `stream` with 32 threads per block.
// Unlike the K-quant wrappers, the grid size is k / QK_K rounded up.
template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int threads_per_block = 32;
    const int num_superblocks   = (k + QK_K - 1) / QK_K;   // ceil-div
    dequantize_block_iq4_xs<<<num_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
|
||||||
|
|
||||||
|
// Maps a numeric quantization type code to the host-side launcher that
// dequantizes that format. Returns nullptr for any code not listed here.
// NOTE(review): the codes look like GGML tensor-type ids - confirm against the caller.
static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
    to_fp16_cuda_t launcher = nullptr;

    switch (type) {
        // Simple block formats go through the generic dequantize_block kernel.
        case 2:  launcher = dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; break;
        case 3:  launcher = dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>; break;
        case 6:  launcher = dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>; break;
        case 7:  launcher = dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>; break;
        case 8:  launcher = dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; break;

        // K-quant formats.
        case 10: launcher = dequantize_row_q2_K_cuda;    break;
        case 11: launcher = dequantize_row_q3_K_cuda;    break;
        case 12: launcher = dequantize_row_q4_K_cuda;    break;
        case 13: launcher = dequantize_row_q5_K_cuda;    break;
        case 14: launcher = dequantize_row_q6_K_cuda;    break;

        // IQ formats.
        case 16: launcher = dequantize_row_iq2_xxs_cuda; break;
        case 17: launcher = dequantize_row_iq2_xs_cuda;  break;
        case 18: launcher = dequantize_row_iq3_xxs_cuda; break;
        case 19: launcher = dequantize_row_iq1_s_cuda;   break;
        case 20: launcher = dequantize_row_iq4_nl_cuda;  break;
        case 21: launcher = dequantize_row_iq3_s_cuda;   break;
        case 22: launcher = dequantize_row_iq2_s_cuda;   break;
        case 23: launcher = dequantize_row_iq4_xs_cuda;  break;

        default: break;
    }

    return launcher;
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user