[doc] Fold long code blocks to improve readability (#19926)
Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
@@ -448,27 +448,29 @@ elements of the entire head for all context tokens. However, overall,
|
||||
all results for output have been calculated but are just stored in
|
||||
different threads' register memory.
|
||||
|
||||
```cpp
|
||||
float* out_smem = reinterpret_cast<float*>(shared_mem);
|
||||
for (int i = NUM_WARPS; i > 1; i /= 2) {
|
||||
// Upper warps write to shared memory.
|
||||
...
|
||||
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
...
|
||||
dst[row_idx] = accs[i];
|
||||
}
|
||||
??? Code
|
||||
|
||||
// Lower warps update the output.
|
||||
const float* src = &out_smem[warp_idx * HEAD_SIZE];
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
```cpp
|
||||
float* out_smem = reinterpret_cast<float*>(shared_mem);
|
||||
for (int i = NUM_WARPS; i > 1; i /= 2) {
|
||||
// Upper warps write to shared memory.
|
||||
...
|
||||
accs[i] += src[row_idx];
|
||||
}
|
||||
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
...
|
||||
dst[row_idx] = accs[i];
|
||||
}
|
||||
|
||||
// Write out the accs.
|
||||
}
|
||||
```
|
||||
// Lower warps update the output.
|
||||
const float* src = &out_smem[warp_idx * HEAD_SIZE];
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
...
|
||||
accs[i] += src[row_idx];
|
||||
}
|
||||
|
||||
// Write out the accs.
|
||||
}
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user