[v1] Add PrefixLM support to FlexAttention backend (#27938)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-12-07 23:51:36 +08:00
committed by GitHub
parent 541a2ef892
commit b952f4d3c3
16 changed files with 173 additions and 25 deletions

View File

@@ -4,6 +4,7 @@
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
@@ -1217,6 +1218,19 @@ class ModelConfig:
)
return False
@cached_property
def is_mm_prefix_lm(self) -> bool:
    """Whether to use bidirectional attention for mm positions."""
    # Model types whose multimodal tokens attend bidirectionally
    # (PrefixLM-style) instead of causally.
    MM_PREFIX_LM_MODELS = (
        "gemma3",
        # TODO(Isotr0py): Disable paligemma for now before
        # we support soft-cap attention for FlexAttention.
        # "paligemma",
    )
    # Configs lacking a model_type (default None) are treated as non-PrefixLM.
    model_type = getattr(self.hf_config, "model_type", None)
    return model_type in MM_PREFIX_LM_MODELS
def get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla: