Add miscellaneous updates (#8)
This commit is contained in:
@@ -12,7 +12,7 @@ from cacheflow.models import InputMetadata
|
||||
class OPTCacheFlowAttention(nn.Module):
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__()
|
||||
super(OPTCacheFlowAttention, self).__init__()
|
||||
self.scale = float(scale)
|
||||
|
||||
self.flash_attn = FlashAttention(softmax_scale=self.scale)
|
||||
@@ -106,8 +106,8 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
output = output.view(-1, num_heads, head_size)
|
||||
|
||||
# Compute the attention op for prompts.
|
||||
if input_metadata.num_prompts > 0:
|
||||
num_prompt_tokens = sum(input_metadata.prompt_lens)
|
||||
num_prompt_tokens = input_metadata.num_prompt_tokens
|
||||
if num_prompt_tokens > 0:
|
||||
self.multi_query_kv_attention(
|
||||
output[:num_prompt_tokens],
|
||||
query[:num_prompt_tokens],
|
||||
@@ -126,10 +126,9 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
# Compute the attention op for generation tokens.
|
||||
start_idx = sum(input_metadata.prompt_lens)
|
||||
self.single_query_cached_kv_attention(
|
||||
output[start_idx:],
|
||||
query[start_idx:],
|
||||
output[num_prompt_tokens:],
|
||||
query[num_prompt_tokens:],
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata)
|
||||
|
||||
Reference in New Issue
Block a user