[Optimization] Implement fused add rmsnorm (#1667)
This commit is contained in:
@@ -159,10 +159,14 @@ class QWenBlock(nn.Module):
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.ln_1(hidden_states, residual)
|
||||
hidden_states = self.attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
@@ -170,14 +174,11 @@ class QWenBlock(nn.Module):
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
hidden_states, residual = self.ln_2(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class QWenModel(nn.Module):
|
||||
@@ -210,20 +211,22 @@ class QWenModel(nn.Module):
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.wte(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.h)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
residual,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user