[Core] Optimizing cross-attention QKVParallelLinear computation (#12325)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nick@nlucches-4xa100.c.openshift-330514.internal>
Co-authored-by: NickLucche <nick@nlucches-4xa100.c.openshift-330514.internal>
Nicolò Lucchesi authored on 2025-03-06 10:37:26 +01:00; committed by GitHub
parent 5d802522a7
commit 69ff99fdcd
4 changed files with 121 additions and 44 deletions


@@ -1227,3 +1227,98 @@ class RowParallelLinear(LinearBase):
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s

class QKVCrossParallelLinear(torch.nn.Module):
    """Cross-attention projection pairing a decoder-side Q projection with
    an encoder-side KV projection, while loading checkpoints as a single
    fused QKV module (drop-in replacement for `QKVParallelLinear`)."""

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # Empty placeholder so the module loads as a single fused tensor;
        # actual loading is routed through `weight_loader_weight`.
        self.weight = torch.nn.Parameter()
        set_weight_attrs(self.weight, {
            "weight_loader": self.weight_loader_weight,
        })
        # Use a plain dict rather than submodule attributes so the inner
        # layers' parameters are not auto-registered on this module.
        self.proj = dict()
self.proj["q_proj_decoder"] = ColumnParallelLinear(
input_size=hidden_size,
output_size=total_num_heads * head_size,
bias=bias,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
prefix=f"{prefix}.q_proj_decoder")
self.proj["kv_proj_encoder"] = QKVParallelLinear(
hidden_size=hidden_size,
head_size=head_size,
total_num_heads=0,
total_num_kv_heads=total_num_kv_heads,
bias=bias,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
prefix=f"{prefix}.kv_proj_encoder")
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
if bias:
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"weight_loader": self.weight_loader_bias,
})

    @property
    def q_proj_decoder(self):
        return self.proj["q_proj_decoder"]

    @property
    def kv_proj_encoder(self):
        return self.proj["kv_proj_encoder"]

    def forward(self, decoder_hidden_states, encoder_hidden_states):
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder KV was already computed and cached
            # during prefill, so only Q is needed.
            k = None
            v = None
        else:
            # Prefill phase: compute encoder KV here (cached downstream).
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split the fused KV output into K and V along the last dim.
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_weight(self,
                             param: torch.nn.Parameter,
                             loaded_weight: torch.Tensor,
                             loaded_shard_id: Optional[str] = None):
        # NOTE: the placeholder `param` is ignored; loading is delegated to
        # the wrapped layer's own weight_loader.
        if loaded_shard_id == "q":
            param = self.q_proj_decoder.weight
            param.weight_loader(param, loaded_weight)
        else:
            param = self.kv_proj_encoder.weight
            param.weight_loader(param, loaded_weight, loaded_shard_id)

    def weight_loader_bias(self,
                           param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor,
                           loaded_shard_id: Optional[str] = None):
        # Same routing as `weight_loader_weight`, but for bias tensors.
        if loaded_shard_id == "q":
            param = self.q_proj_decoder.bias
            param.weight_loader(param, loaded_weight)
        else:
            param = self.kv_proj_encoder.bias
            param.weight_loader(param, loaded_weight, loaded_shard_id)
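
A note on the `self.proj = dict()` trick above: storing the inner layers in a plain dict keeps their parameters out of this module's `named_parameters()`, so the checkpoint loader only ever sees the placeholder `weight`/`bias`. A minimal, self-contained PyTorch sketch of the effect (the two class names here are illustrative, not part of the diff):

import torch

class WithModuleDict(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleDict registers the inner layer's parameters.
        self.proj = torch.nn.ModuleDict({"q": torch.nn.Linear(4, 4)})

class WithPlainDict(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder parameter, as in `QKVCrossParallelLinear`.
        self.weight = torch.nn.Parameter()
        # A plain dict is invisible to nn.Module bookkeeping.
        self.proj = {"q": torch.nn.Linear(4, 4)}

print([n for n, _ in WithModuleDict().named_parameters()])
# -> ['proj.q.weight', 'proj.q.bias']
print([n for n, _ in WithPlainDict().named_parameters()])
# -> ['weight']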
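The `loaded_shard_id` values ("q", "k", "v") come from the model definition's stacked-parameter mapping. An illustrative mapping of the kind vLLM model files commonly declare (the `qkv_proj` name is hypothetical, chosen for the example):

# Checkpoint tensors named `*.q_proj.*`, `*.k_proj.*`, `*.v_proj.*` are all
# routed to a single `qkv_proj` parameter with the matching shard id, which
# then lands in `weight_loader_weight` / `weight_loader_bias` above.
stacked_params_mapping = [
    # (param_name, weight_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

With such a mapping, "q" shards are forwarded to `q_proj_decoder`'s `ColumnParallelLinear` loader (which takes no shard id), while "k"/"v" shards reach `kv_proj_encoder`'s `QKVParallelLinear` loader with the shard id preserved.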
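Finally, an end-to-end usage sketch, not part of the diff: it assumes vLLM's `init_distributed_environment` / `initialize_model_parallel` helpers for a single-process TP group, and the sizes are illustrative.

import torch

from vllm.distributed import (init_distributed_environment,
                              initialize_model_parallel)
from vllm.model_executor.layers.linear import QKVCrossParallelLinear

# Single-process setup so the parallel layers can be constructed.
init_distributed_environment(world_size=1, rank=0, local_rank=0,
                             distributed_init_method="tcp://127.0.0.1:29500",
                             backend="gloo")
initialize_model_parallel(tensor_model_parallel_size=1)

hidden_size, head_size, num_heads, num_kv_heads = 512, 64, 8, 8
qkv = QKVCrossParallelLinear(hidden_size=hidden_size,
                             head_size=head_size,
                             total_num_heads=num_heads,
                             total_num_kv_heads=num_kv_heads,
                             bias=False)

decoder_hidden = torch.randn(4, hidden_size)   # 4 decoder tokens
encoder_hidden = torch.randn(16, hidden_size)  # 16 encoder tokens

# Prefill: one call computes decoder Q and encoder K/V.
q, k, v = qkv(decoder_hidden, encoder_hidden)
assert q.shape == (4, num_heads * head_size)
assert k.shape == v.shape == (16, num_kv_heads * head_size)

# Decode: encoder K/V already cached, so only Q is computed.
q, k, v = qkv(decoder_hidden, None)
assert k is None and v is None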