Fix figures in design doc (#18612)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -140,22 +140,18 @@ title: vLLM Paged Attention
|
|||||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||||
```
|
```
|
||||||
|
|
||||||
<figure markdown="span">
|
<figure markdown="span">
|
||||||
{ align="center" alt="query" width="70%" }
|
{ align="center" alt="query" width="70%" }
|
||||||
<figcaption>
|
</figure>
|
||||||
</figcaption>
|
|
||||||
</figure>
|
|
||||||
|
|
||||||
- Each thread defines its own `q_ptr` which points to the assigned
|
- Each thread defines its own `q_ptr` which points to the assigned
|
||||||
query token data on global memory. For example, if `VEC_SIZE` is 4
|
query token data on global memory. For example, if `VEC_SIZE` is 4
|
||||||
and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains
|
and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains
|
||||||
total of 128 elements divided into 128 / 4 = 32 vecs.
|
total of 128 elements divided into 128 / 4 = 32 vecs.
|
||||||
|
|
||||||
<figure markdown="span">
|
<figure markdown="span">
|
||||||
{ align="center" alt="q_vecs" width="70%" }
|
{ align="center" alt="q_vecs" width="70%" }
|
||||||
<figcaption>
|
</figure>
|
||||||
</figcaption>
|
|
||||||
</figure>
|
|
||||||
|
|
||||||
```cpp
|
```cpp
|
||||||
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||||
@@ -192,11 +188,9 @@ title: vLLM Paged Attention
|
|||||||
points to key token data based on `k_cache` at assigned block,
|
points to key token data based on `k_cache` at assigned block,
|
||||||
assigned head and assigned token.
|
assigned head and assigned token.
|
||||||
|
|
||||||
<figure markdown="span">
|
<figure markdown="span">
|
||||||
{ align="center" alt="key" width="70%" }
|
{ align="center" alt="key" width="70%" }
|
||||||
<figcaption>
|
</figure>
|
||||||
</figcaption>
|
|
||||||
</figure>
|
|
||||||
|
|
||||||
- The diagram above illustrates the memory layout for key data. It
|
- The diagram above illustrates the memory layout for key data. It
|
||||||
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
|
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
|
||||||
@@ -209,11 +203,9 @@ title: vLLM Paged Attention
|
|||||||
elements for one token) that will be processed by 2 threads (one
|
elements for one token) that will be processed by 2 threads (one
|
||||||
thread group) separately.
|
thread group) separately.
|
||||||
|
|
||||||
<figure markdown="span">
|
<figure markdown="span">
|
||||||
{ align="center" alt="k_vecs" width="70%" }
|
{ align="center" alt="k_vecs" width="70%" }
|
||||||
<figcaption>
|
</figure>
|
||||||
</figcaption>
|
|
||||||
</figure>
|
|
||||||
|
|
||||||
```cpp
|
```cpp
|
||||||
K_vec k_vecs[NUM_VECS_PER_THREAD]
|
K_vec k_vecs[NUM_VECS_PER_THREAD]
|
||||||
@@ -372,20 +364,14 @@ title: vLLM Paged Attention
|
|||||||
|
|
||||||
<figure markdown="span">
|
<figure markdown="span">
|
||||||
{ align="center" alt="value" width="70%" }
|
{ align="center" alt="value" width="70%" }
|
||||||
<figcaption>
|
|
||||||
</figcaption>
|
|
||||||
</figure>
|
</figure>
|
||||||
|
|
||||||
<figure markdown="span">
|
<figure markdown="span">
|
||||||
{ align="center" alt="logits_vec" width="50%" }
|
{ align="center" alt="logits_vec" width="50%" }
|
||||||
<figcaption>
|
|
||||||
</figcaption>
|
|
||||||
</figure>
|
</figure>
|
||||||
|
|
||||||
<figure markdown="span">
|
<figure markdown="span">
|
||||||
{ align="center" alt="v_vec" width="70%" }
|
{ align="center" alt="v_vec" width="70%" }
|
||||||
<figcaption>
|
|
||||||
</figcaption>
|
|
||||||
</figure>
|
</figure>
|
||||||
|
|
||||||
- Now we need to retrieve the value data and perform dot multiplication
|
- Now we need to retrieve the value data and perform dot multiplication
|
||||||
|
|||||||
Reference in New Issue
Block a user