[DCP][Bugfix][CI] Fix accuracy issue of DCP when using FLASH_ATTN_MLA (#30309)
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
@@ -123,8 +123,11 @@ class CPTestSettings:
 CP_TEXT_GENERATION_MODELS = {
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
+        CPTestSettings.detailed(dcp_multipliers=[1]),
         CPTestSettings.detailed(
-            dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
+            dcp_multipliers=[0.5],
+            cp_kv_cache_interleave_size=64,
+            attn_backend="FLASHMLA",
         ),
     ],
     "Qwen/Qwen2.5-1.5B-Instruct": [
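(Sketch, not part of this commit.) `cp_kv_cache_interleave_size` controls how KV-cache entries are striped across decode context parallel (DCP) ranks; my understanding is that consecutive blocks of `interleave_size` tokens are assigned to successive ranks in round-robin order, roughly as in the toy helper below. `dcp_rank_of_token` is a made-up name used only for illustration.

def dcp_rank_of_token(token_idx: int, interleave_size: int, dcp_world_size: int) -> int:
    # Assumed layout (illustrative, not vLLM's code): blocks of
    # `interleave_size` consecutive tokens go to DCP ranks round-robin.
    return (token_idx // interleave_size) % dcp_world_size

# interleave_size=1 (default): tokens alternate between ranks.
print([dcp_rank_of_token(t, 1, 2) for t in range(6)])                # [0, 1, 0, 1, 0, 1]
# interleave_size=64: each rank owns runs of 64 tokens, as exercised above.
print([dcp_rank_of_token(t, 64, 2) for t in (0, 63, 64, 127, 128)])  # [0, 0, 1, 1, 0]

With this change, the interleaved (size-64) case is pinned to the FLASHMLA backend, while the new `dcp_multipliers=[1]` entry keeps the default backend exercised with the default interleave size.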
@@ -105,13 +105,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         vllm_config: VllmConfig,
         device: torch.device,
     ):
+        interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
         super().__init__(
             kv_cache_spec,
             layer_names,
             vllm_config,
             device,
             FlashAttnMLAMetadata,
-            supports_dcp_with_varlen=True,
+            supports_dcp_with_varlen=(interleave_size == 1),
         )
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = get_flash_attn_version() == 3
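(Sketch, not part of this commit.) Presumably the common MLA code consults `supports_dcp_with_varlen` before planning the variable-length DCP attention path; after this change, FLASH_ATTN_MLA only advertises that support when the KV cache is not interleaved across DCP ranks (`interleave_size == 1`). A toy illustration of the capability gate, using made-up names (`DummyBuilder`, `plan_decode`):

class DummyBuilder:
    # Toy stand-in for a metadata builder; not vLLM's class.
    def __init__(self, interleave_size: int) -> None:
        # Mirrors the changed line: only claim varlen-DCP support when the
        # KV cache is not interleaved across DCP ranks.
        self.supports_dcp_with_varlen = interleave_size == 1

def plan_decode(builder: DummyBuilder, dcp_world_size: int) -> str:
    # Hypothetical caller: take the varlen DCP path only if it is advertised.
    if dcp_world_size > 1 and builder.supports_dcp_with_varlen:
        return "varlen-dcp"
    return "fallback"

print(plan_decode(DummyBuilder(interleave_size=1), dcp_world_size=2))   # varlen-dcp
print(plan_decode(DummyBuilder(interleave_size=64), dcp_world_size=2))  # fallback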