Add docstrings to some modules and classes (#100)
@@ -1,3 +1,4 @@
"""Multi-head attention."""
from typing import Optional

import torch
@@ -11,6 +12,32 @@ from cacheflow.model_executor.input_metadata import InputMetadata
class GPTCacheFlowAttention(nn.Module):
    """GPT-style multi-head attention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can be split into three parts: the prompt tokens, the
    generation tokens, and the paddings.

    |<------------------------------------- num_valid_tokens ------------------------------------->|
    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
       not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
       operations are issued by the cache engine before executing the forward
       pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
       This operation reads the previous key and value tensors from the KV
       cache.
    5. Output a flattened 1D tensor.
"""
def __init__(self, scale: float) -> None:
super().__init__()
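To make step 1 of the docstring concrete, here is a hedged, plain-PyTorch sketch of the attention that multi_query_kv_attention performs over a single prompt, without the optimized kernel and without the KV cache. The function name and tensor shapes are illustrative assumptions, not the cacheflow implementation.

import torch

def prompt_attention(q, k, v, scale):
    # q, k, v: [num_tokens, num_heads, head_size] for one prompt segment.
    q, k, v = (t.transpose(0, 1) for t in (q, k, v))  # [heads, tokens, size]
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    # Causal mask: each token attends only to itself and earlier tokens.
    num_tokens = scores.size(-1)
    mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
    scores = scores.masked_fill(mask, float('-inf'))
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    return out.transpose(0, 1)  # back to [num_tokens, num_heads, head_size]

# One prompt of 5 tokens, 2 heads, head size 4.
q = k = v = torch.randn(5, 2, 4)
print(prompt_attention(q, k, v, scale=4 ** -0.5).shape)  # torch.Size([5, 2, 4])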
@@ -157,7 +184,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
- self.register_buffer('cos_sin_cache', cache, persistent=False)
+ self.register_buffer("cos_sin_cache", cache, persistent=False)
def forward(
self,