[doc] use MkDocs collapsible blocks - supplement (#19973)
Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
@@ -28,27 +28,29 @@ A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all
|
||||
|
||||
In the very verbose logs, we can see:
|
||||
|
||||
```
|
||||
DEBUG 03-07 03:06:52 [decorators.py:203] Start compiling function <code object forward at 0x7f08acf40c90, file "xxx/vllm/model_executor/models/llama.py", line 339>
|
||||
??? Logs
|
||||
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] Traced files (to be considered for compilation cache):
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/_dynamo/polyfills/builtins.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/container.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/module.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/attention/layer.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/communication_op.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/parallel_state.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/custom_op.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/activation.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/layernorm.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/linear.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/rotary_embedding.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/vocab_parallel_embedding.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/models/llama.py
|
||||
```text
|
||||
DEBUG 03-07 03:06:52 [decorators.py:203] Start compiling function <code object forward at 0x7f08acf40c90, file "xxx/vllm/model_executor/models/llama.py", line 339>
|
||||
|
||||
DEBUG 03-07 03:07:07 [backends.py:462] Computation graph saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py
|
||||
DEBUG 03-07 03:07:07 [wrapper.py:105] Dynamo transformed code saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/transformed_code.py
|
||||
```
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] Traced files (to be considered for compilation cache):
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/_dynamo/polyfills/builtins.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/container.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/module.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/attention/layer.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/communication_op.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/parallel_state.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/custom_op.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/activation.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/layernorm.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/linear.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/rotary_embedding.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/vocab_parallel_embedding.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/models/llama.py
|
||||
|
||||
DEBUG 03-07 03:07:07 [backends.py:462] Computation graph saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py
|
||||
DEBUG 03-07 03:07:07 [wrapper.py:105] Dynamo transformed code saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/transformed_code.py
|
||||
```
|
||||
|
||||
This is about the Python code compilation, i.e. graph capture by Dynamo. It tries to trace the function with code `xxx/vllm/model_executor/models/llama.py:339`, which is the `forward` function of the model we compile. During the forward pass, there are also other functions called and inlined by Dynamo, as shown by the logs, including some PyTorch functions from `xxx/torch/nn/modules/module.py` (used by PyTorch `nn.Module`, because module attribute access will trigger a function call), some communication / attention / activation functions from vLLM. All the traced files will be considered when we decide the cache directory to use. This way, any code change in the above files will trigger compilation cache miss, and therefore recompilation.
|
||||
|
||||
@@ -99,28 +101,31 @@ This time, Inductor compilation is completely bypassed, and we will load from di
|
||||
|
||||
The above example just uses Inductor to compile for a general shape (i.e. symbolic shape). We can also use Inductor to compile for some of the specific shapes, for example:
|
||||
|
||||
```
|
||||
vllm serve meta-llama/Llama-3.2-1B --compilation_config '{"compile_sizes": [1, 2, 4, 8]}'
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-3.2-1B \
|
||||
--compilation_config '{"compile_sizes": [1, 2, 4, 8]}'
|
||||
```
|
||||
|
||||
Then it will also compile a specific kernel just for batch size `1, 2, 4, 8`. At this time, all of the shapes in the computation graph are static and known, and we will turn on auto-tuning to tune for max performance. This can be slow when you run it for the first time, but the next time you run it, we can directly bypass the tuning and run the tuned kernel.
|
||||
|
||||
When all the shapes are known, `torch.compile` can compare different configs, and often find some better configs to run the kernel. For example, we can see the following log:
|
||||
|
||||
```
|
||||
AUTOTUNE mm(8x2048, 2048x3072)
|
||||
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
|
||||
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
|
||||
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
|
||||
mm 0.0160 ms 81.6%
|
||||
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
|
||||
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
|
||||
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2
|
||||
triton_mm_7 0.0203 ms 64.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
|
||||
triton_mm_2 0.0208 ms 62.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
|
||||
triton_mm_11 0.0215 ms 60.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
|
||||
SingleProcess AUTOTUNE benchmarking takes 2.0428 seconds and 7.5727 seconds precompiling
|
||||
```
|
||||
??? Logs
|
||||
|
||||
```
|
||||
AUTOTUNE mm(8x2048, 2048x3072)
|
||||
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
|
||||
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
|
||||
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
|
||||
mm 0.0160 ms 81.6%
|
||||
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
|
||||
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
|
||||
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2
|
||||
triton_mm_7 0.0203 ms 64.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
|
||||
triton_mm_2 0.0208 ms 62.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
|
||||
triton_mm_11 0.0215 ms 60.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
|
||||
SingleProcess AUTOTUNE benchmarking takes 2.0428 seconds and 7.5727 seconds precompiling
|
||||
```
|
||||
|
||||
It means, for a matrix multiplication with shape `8x2048x3072`, `torch.compile` tries triton template with various configs, and it is much faster than the default code (which dispatches to cublas library).
|
||||
|
||||
@@ -136,8 +141,9 @@ The cudagraphs are captured and managed by the compiler backend, and replayed wh
|
||||
|
||||
By default, vLLM will try to determine a set of sizes to capture cudagraph. You can also override it using the config `cudagraph_capture_sizes`:
|
||||
|
||||
```
|
||||
vllm serve meta-llama/Llama-3.2-1B --compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8]}'
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-3.2-1B \
|
||||
--compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8]}'
|
||||
```
|
||||
|
||||
Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture.
|
||||
|
||||
Reference in New Issue
Block a user