[Bugfix] Remove hardcoded head_size=256 for Deepseek v2 and v3 (#12067)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-01-16 18:11:54 +08:00
committed by GitHub
parent 9aa1519f08
commit dd7c9ad870
4 changed files with 23 additions and 40 deletions

View File

@@ -262,14 +262,8 @@ class DeepseekV2Attention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)
# TODO, support head_size 192
self.attn = Attention(self.num_local_heads,
256,
self.qk_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
@@ -319,18 +313,14 @@ class DeepseekV2Attention(nn.Module):
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
# padding value to qk_head_dim for alignment
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output