Optimize data movement (#20)
This commit is contained in:
@@ -69,17 +69,14 @@ class OPTAttention(nn.Module):
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
|
||||
q, k, v = torch.split(qkv, 1, dim=-2)
|
||||
q = q.squeeze(dim=-2).contiguous()
|
||||
k = k.squeeze(dim=-2).contiguous()
|
||||
v = v.squeeze(dim=-2).contiguous()
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
|
||||
Reference in New Issue
Block a user