[Bugfix] Remove hardcoded head_size=256 for Deepseek v2 and v3 (#12067)
Signed-off-by: Isotr0py <2037008807@qq.com>
@@ -733,9 +733,12 @@ class ModelConfig:
         if hasattr(self.hf_text_config,
                    "model_type") and (self.hf_text_config.model_type
                                       in ('deepseek_v2', 'deepseek_v3')):
-            # FlashAttention supports only head_size 32, 64, 128, 256,
-            # we need to pad head_size 192 to 256
-            return 256
+            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
+                                       0)
+            qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
+                                       0)
+            if qk_rope_head_dim and qk_nope_head_dim:
+                return qk_rope_head_dim + qk_nope_head_dim
 
         if self.is_attention_free:
             return 0
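For reference, a minimal runnable sketch of the new branch's arithmetic. The SimpleNamespace object below is a hypothetical stand-in for self.hf_text_config (not the real transformers config class), filled with DeepSeek-V2's MLA head dims (qk_nope_head_dim=128, qk_rope_head_dim=64), which add up to the head_size 192 that the removed comment padded to 256:

    from types import SimpleNamespace

    # Hypothetical stand-in for self.hf_text_config with
    # DeepSeek-V2's published MLA head dims.
    hf_text_config = SimpleNamespace(model_type="deepseek_v2",
                                     qk_nope_head_dim=128,
                                     qk_rope_head_dim=64)

    # Mirrors the patched branch of ModelConfig.get_head_size():
    # derive the head size from the config instead of hardcoding 256.
    qk_rope_head_dim = getattr(hf_text_config, "qk_rope_head_dim", 0)
    qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 0)
    if qk_rope_head_dim and qk_nope_head_dim:
        head_size = qk_rope_head_dim + qk_nope_head_dim

    print(head_size)  # 192, the value previously padded to 256

The `if qk_rope_head_dim and qk_nope_head_dim` guard keeps the change conservative: if either attribute is missing or zero, control falls through to the checks that follow in the hunk (e.g. the is_attention_free case) rather than returning a bogus head size.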