[Doc] Clarify FP8 KV cache computation workflow (#31071)

Signed-off-by: westers <steve.westerhouse@origami-analytics.com>
This commit is contained in:
Steve Westerhouse
2025-12-21 18:41:37 -06:00
committed by GitHub
parent 06d490282f
commit 9d701e90d8
2 changed files with 31 additions and 21 deletions

View File

@@ -139,18 +139,18 @@ token data.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
```
<figure markdown="span">
![](../assets/design/paged_attention/query.png){ align="center" alt="query" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/query.png" alt="query" width="70%" />
</p>
Each thread defines its own `q_ptr` which points to the assigned
query token data on global memory. For example, if `VEC_SIZE` is 4
and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains
total of 128 elements divided into 128 / 4 = 32 vecs.
<figure markdown="span">
![](../assets/design/paged_attention/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/q_vecs.png" alt="q_vecs" width="70%" />
</p>
```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
@@ -187,9 +187,9 @@ key token at different iterations. As shown above, that `k_ptr`
points to key token data based on `k_cache` at assigned block,
assigned head and assigned token.
<figure markdown="span">
![](../assets/design/paged_attention/key.png){ align="center" alt="key" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/key.png" alt="key" width="70%" />
</p>
The diagram above illustrates the memory layout for key data. It
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
@@ -202,9 +202,9 @@ iterations. Inside each rectangle, there are a total 32 vecs (128
elements for one token) that will be processed by 2 threads (one
thread group) separately.
<figure markdown="span">
![](../assets/design/paged_attention/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/k_vecs.png" alt="k_vecs" width="70%" />
</p>
```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD]
@@ -361,17 +361,17 @@ later steps. Now, it should store the normalized softmax result of
## Value
<figure markdown="span">
![](../assets/design/paged_attention/value.png){ align="center" alt="value" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/value.png" alt="value" width="70%" />
</p>
<figure markdown="span">
![](../assets/design/paged_attention/logits_vec.png){ align="center" alt="logits_vec" width="50%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/logits_vec.png" alt="logits_vec" width="50%" />
</p>
<figure markdown="span">
![](../assets/design/paged_attention/v_vec.png){ align="center" alt="v_vec" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/v_vec.png" alt="v_vec" width="70%" />
</p>
Now we need to retrieve the value data and perform dot multiplication
with `logits`. Unlike query and key, there is no thread group