Add docstrings to some modules and classes (#100)

This commit is contained in:
Woosuk Kwon
2023-05-14 22:32:38 -07:00
committed by GitHub
parent 667ba3995c
commit b322fd1607
17 changed files with 166 additions and 31 deletions

View File

@@ -1,3 +1,4 @@
"""Multi-head attention."""
from typing import Optional
import torch
@@ -11,6 +12,32 @@ from cacheflow.model_executor.input_metadata import InputMetadata
class GPTCacheFlowAttention(nn.Module):
"""GPT-style multi-head attention.
This class takes flattened 1D query, key, and value tensors as input. The
input 1D tensors can be split into three parts: the prompt tokens, the
generation tokens, and the paddings.
|<------------------------------------- num_valid_tokens ------------------------------------->|
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
The prompts might have different lengths, while each generation sequence
always contributes exactly one token. Padding is appended to make the input
length a multiple of 8, which is desirable for Tensor Cores.
The class does the following:
1. Perform multi_query_kv_attention for the prompts. This operation does
not use the KV cache.
2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
operations are issued by the cache engine before executing the forward
pass of the model, and they are executed asynchronously.
3. Reshape and store the input key and value tensors in the KV cache.
4. Perform single_query_cached_kv_attention for the generation tokens.
This operation reads the previous key and value tensors from the KV
cache.
5. Output a flattened 1D tensor.
"""
def __init__(self, scale: float) -> None:
super().__init__()
@@ -157,7 +184,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
self.register_buffer('cos_sin_cache', cache, persistent=False)
self.register_buffer("cos_sin_cache", cache, persistent=False)
def forward(
self,