Optimize data movement (#20)
@@ -11,6 +11,7 @@ from torch import nn
 from transformers import LlamaConfig

 from cacheflow.models import InputMetadata
+from cacheflow.models.activation import SiluAndMul
 from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
@@ -39,16 +40,14 @@ class LlamaMLP(nn.Module):
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        assert hidden_act == 'silu'
-        self.act_fn = nn.SiLU()
+        if hidden_act != 'silu':
+            raise ValueError(f'Unsupported activation: {hidden_act}. '
+                             'Only silu is supported for now.')
+        self.act_fn = SiluAndMul()

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
-        gate, up = torch.split(gate_up, 1, dim=-2)
-        gate = gate.squeeze(dim=-2).contiguous()
-        up = up.squeeze(dim=-2).contiguous()
-        x = self.act_fn(gate) * up
+        x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x

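Note on the LlamaMLP hunk: the old forward pass reshaped and split gate_up, forced two .contiguous() copies, and applied nn.SiLU() before the elementwise multiply; the new code passes the concatenated gate_up projection straight into SiluAndMul, so the split and activation happen in one fused op with no intermediate copies. A minimal sketch of the semantics SiluAndMul is assumed to implement follows (the real cacheflow.models.activation.SiluAndMul presumably dispatches to a fused kernel; the class name and body here are illustrative only):

import torch
from torch import nn
import torch.nn.functional as F

class SiluAndMulSketch(nn.Module):
    # Illustrative reference only; not the cacheflow implementation.
    def forward(self, gate_up: torch.Tensor) -> torch.Tensor:
        # gate_up is the concatenated [gate, up] projection, shape (..., 2 * d).
        d = gate_up.shape[-1] // 2
        gate, up = gate_up[..., :d], gate_up[..., d:]
        # silu(gate) * up, computed directly on views of gate_up,
        # without the intermediate .contiguous() copies of the old path.
        return F.silu(gate) * up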
@@ -94,11 +93,7 @@ class LlamaAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
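Note on the LlamaAttention hunk: torch.split over a size-1 dim followed by squeeze and .contiguous() copies each of q, k, and v out of the QKV projection output, while qkv.chunk(chunks=3, dim=-1) simply returns three views of the same buffer; that removed copying is the data movement this commit optimizes away. A small illustrative check of the equivalence, with a made-up tensor shape for the example:

import torch

qkv = torch.randn(16, 3 * 128)  # hypothetical (num_tokens, 3 * hidden) projection output

# Old path: reshape + split + squeeze + contiguous materializes three copies.
q_old, k_old, v_old = (
    t.squeeze(dim=-2).contiguous()
    for t in torch.split(qkv.reshape(qkv.shape[:-1] + (3, -1)), 1, dim=-2)
)

# New path: chunk returns views into qkv's storage; no copy is made here.
q_new, k_new, v_new = qkv.chunk(chunks=3, dim=-1)

assert torch.equal(q_old, q_new)
assert torch.equal(k_old, k_new)
assert torch.equal(v_old, v_new)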