Merge QKV into one linear layer (#15)

This commit is contained in:
Zhuohan Li
2023-04-02 15:23:29 +08:00
committed by GitHub
parent 2c5cd0defe
commit 1f01a18d39
2 changed files with 90 additions and 83 deletions

View File

@@ -53,16 +53,9 @@ class OPTAttention(nn.Module):
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5
# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
@@ -75,16 +68,18 @@ class OPTAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = torch.split(qkv, 1, dim=-2)
q = q.squeeze(dim=-2).contiguous()
k = k.squeeze(dim=-2).contiguous()
v = v.squeeze(dim=-2).contiguous()
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
@@ -262,11 +257,7 @@ class OPTForCausalLM(nn.Module):
self.lm_head_weight, hidden_states, input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight",
"q_proj.weight", "k_proj.weight",
"v_proj.weight", "fc1.weight"]
_column_parallel_biases = ["q_proj.bias", "k_proj.bias",
"v_proj.bias", "fc1.bias"]
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, weights_path: str):
@@ -275,24 +266,35 @@ class OPTForCausalLM(nn.Module):
for name, param in state_dict.items():
if "lm_head_weight" in name:
continue
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
name)))
for p in (self._column_parallel_weights
+ self._column_parallel_biases):
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
if "qkv_proj" in name:
shard_size = param.shape[0] // 3
weights_to_concat = []
for weight_name in ["q_proj", "k_proj", "v_proj"]:
weight = np.load(os.path.join(
weights_path, name.replace("qkv_proj", weight_name)))
weights_to_concat.append(weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
:shard_size * (tensor_model_parallel_rank + 1)])
loaded_weight = torch.from_numpy(
np.concatenate(weights_to_concat, axis=0))
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)