Optimize data movement (#20)

This commit is contained in:
Woosuk Kwon
2023-04-02 00:30:17 -07:00
committed by GitHub
parent 1f01a18d39
commit 897cb2ae28
17 changed files with 275 additions and 135 deletions

View File

@@ -11,6 +11,7 @@ from torch import nn
from transformers import LlamaConfig
from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
@@ -39,16 +40,14 @@ class LlamaMLP(nn.Module):
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
assert hidden_act == 'silu'
self.act_fn = nn.SiLU()
if hidden_act != 'silu':
raise ValueError(f'Unsupported activation: {hidden_act}. '
'Only silu is supported for now.')
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
gate, up = torch.split(gate_up, 1, dim=-2)
gate = gate.squeeze(dim=-2).contiguous()
up = up.squeeze(dim=-2).contiguous()
x = self.act_fn(gate) * up
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
@@ -94,11 +93,7 @@ class LlamaAttention(nn.Module):
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = torch.split(qkv, 1, dim=-2)
q = q.squeeze(dim=-2).contiguous()
k = k.squeeze(dim=-2).contiguous()
v = v.squeeze(dim=-2).contiguous()
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)