Add miscellaneous updates (#8)

This commit is contained in:
Woosuk Kwon
2023-03-13 13:48:38 -07:00
committed by GitHub
parent e9d3f2ff77
commit cfae35b861
7 changed files with 44 additions and 22 deletions

View File

@@ -12,7 +12,7 @@ from cacheflow.models import InputMetadata
class OPTCacheFlowAttention(nn.Module):
def __init__(self, scale: float) -> None:
- super().__init__()
+ super(OPTCacheFlowAttention, self).__init__()
self.scale = float(scale)
self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -106,8 +106,8 @@ class OPTCacheFlowAttention(nn.Module):
output = output.view(-1, num_heads, head_size)
# Compute the attention op for prompts.
- if input_metadata.num_prompts > 0:
- num_prompt_tokens = sum(input_metadata.prompt_lens)
+ num_prompt_tokens = input_metadata.num_prompt_tokens
+ if num_prompt_tokens > 0:
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
@@ -126,10 +126,9 @@ class OPTCacheFlowAttention(nn.Module):
if input_metadata.num_generation_tokens > 0:
# Compute the attention op for generation tokens.
- start_idx = sum(input_metadata.prompt_lens)
  self.single_query_cached_kv_attention(
- output[start_idx:],
- query[start_idx:],
+ output[num_prompt_tokens:],
+ query[num_prompt_tokens:],
key_cache,
value_cache,
input_metadata)