Optimize model execution with CUDA graph (#1926)

Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
Woosuk Kwon
2023-12-16 21:12:08 -08:00
committed by GitHub
parent eed74a558f
commit 37ca558103
34 changed files with 557 additions and 254 deletions

View File

@@ -178,7 +178,6 @@ class FalconAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
@@ -187,8 +186,7 @@ class FalconAttention(nn.Module):
if self.use_rotary:
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
@@ -266,8 +264,7 @@ class FalconDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
):
) -> torch.Tensor:
residual = hidden_states
if self.config.new_decoder_architecture:
@@ -282,7 +279,6 @@ class FalconDecoderLayer(nn.Module):
hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias
@@ -311,7 +307,6 @@ class FalconDecoderLayer(nn.Module):
mlp_output += mlp_bias
output = mlp_output + residual
return output
@@ -349,18 +344,15 @@ class FalconModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@@ -389,14 +381,12 @@ class FalconForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
input_metadata,
cache_events,
)
return hidden_states