Add contributing guideline and mypy config (#122)

This commit is contained in:
Woosuk Kwon
2023-05-23 17:58:51 -07:00
committed by GitHub
parent 3f942acfe1
commit a283ec2eec
16 changed files with 128 additions and 44 deletions

View File

@@ -61,7 +61,7 @@ class GPTCacheFlowAttention(nn.Module):
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
attn_bias: xops.AttentionBias,
) -> None:
) -> torch.Tensor:
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
@@ -197,7 +197,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
def forward(
self,
positions: torch.LongTensor, # [num_tokens]
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]

View File

@@ -347,7 +347,7 @@ def _sample_from_generation_tokens(
# Greedy sampling.
assert len(seq_ids) == 1
next_token_id = torch.argmax(probs, dim=-1)
next_token_ids = [next_token_id.item()]
next_token_ids = [int(next_token_id.item())]
parent_seq_ids = seq_ids
else:
# Random sampling.