[Bugfix] Remove hardcoded head_size=256 for Deepseek v2 and v3 (#12067)
Signed-off-by: Isotr0py <2037008807@qq.com>
@@ -262,14 +262,8 @@ class DeepseekV2Attention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
-        # TODO, support head_size 192
         self.attn = Attention(self.num_local_heads,
-                              256,
+                              self.qk_head_dim,
                               self.scaling,
                               num_kv_heads=self.num_local_heads,
                               cache_config=cache_config,
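Note: the head size passed to Attention is now the model's real qk_head_dim rather than a hard-coded 256. A minimal sketch of that arithmetic, assuming DeepSeek-V2's usual MLA dimensions (qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128 come from the model config, not from this diff):

    # Assumed DeepSeek-V2 MLA dimensions, for illustration only.
    qk_nope_head_dim = 128   # non-rotary part of each q/k head
    qk_rope_head_dim = 64    # rotary (RoPE) part of each q/k head
    v_head_dim = 128         # value head size

    # head_size handed to Attention() after this change:
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim   # 192, not 256

    # Only v still needs zero padding, and only up to qk_head_dim:
    pad = qk_head_dim - v_head_dim                       # 64 zeros per head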
@@ -319,18 +313,14 @@ class DeepseekV2Attention(nn.Module):
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim:] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
+        # padding value to qk_head_dim for alignment
+        v = torch.nn.functional.pad(
+            v, [0, self.qk_head_dim - self.v_head_dim],
+            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output = attn_output.view(
-            -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
+            -1, self.num_local_heads,
+            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
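The forward path correspondingly pads v only up to qk_head_dim and slices the attention output back down to v_head_dim. A self-contained sketch of just that reshaping, with a stand-in for the attention call and made-up sizes (num_local_heads=4 and the head dims above are assumptions, not taken from this diff):

    import torch

    # Assumed sizes, for illustration only.
    num_local_heads = 4
    qk_head_dim = 192          # 128 nope + 64 rope
    v_head_dim = 128
    num_tokens = 3

    q = torch.randn(num_tokens, num_local_heads, qk_head_dim)
    k = torch.randn(num_tokens, num_local_heads, qk_head_dim)
    v = torch.randn(num_tokens, num_local_heads, v_head_dim)

    # Pad v from v_head_dim to qk_head_dim so q, k and v share one head size,
    # then flatten to the (num_tokens, num_heads * head_size) layout the
    # attention backend expects.
    v_padded = torch.nn.functional.pad(
        v, [0, qk_head_dim - v_head_dim],
        value=0).view(-1, num_local_heads * qk_head_dim)
    q_flat = q.view(-1, num_local_heads * qk_head_dim)
    k_flat = k.view(-1, num_local_heads * qk_head_dim)

    # Stand-in for self.attn(q, k, v, ...): any op that keeps the flat shape.
    attn_output = q_flat + k_flat + v_padded

    # Un-flatten and drop the zero padding to recover v_head_dim per head.
    attn_output = attn_output.view(
        -1, num_local_heads, qk_head_dim)[..., :v_head_dim].reshape(
            -1, num_local_heads * v_head_dim)
    print(attn_output.shape)   # torch.Size([3, 512])

In the real kernel the zero-padded columns of v contribute nothing to the attention-weighted sum and are sliced off before o_proj, so the padding only affects layout, not the result.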
@@ -269,14 +269,8 @@ class DeepseekV3Attention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
-        # TODO, support head_size 192
         self.attn = Attention(self.num_local_heads,
-                              256,
+                              self.qk_head_dim,
                               self.scaling,
                               num_kv_heads=self.num_local_heads,
                               cache_config=cache_config,
@@ -326,18 +320,14 @@ class DeepseekV3Attention(nn.Module):
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim:] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
+        # padding value to qk_head_dim for alignment
+        v = torch.nn.functional.pad(
+            v, [0, self.qk_head_dim - self.v_head_dim],
+            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output = attn_output.view(
-            -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
+            -1, self.num_local_heads,
+            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output