Optimize data movement (#20)

This commit is contained in:
Woosuk Kwon
2023-04-02 00:30:17 -07:00
committed by GitHub
parent 1f01a18d39
commit 897cb2ae28
17 changed files with 275 additions and 135 deletions

View File

@@ -69,17 +69,14 @@ class OPTAttention(nn.Module):
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = torch.split(qkv, 1, dim=-2)
q = q.squeeze(dim=-2).contiguous()
k = k.squeeze(dim=-2).contiguous()
v = v.squeeze(dim=-2).contiguous()
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):