[v1] Add PrefixLM support to FlexAttention backend (#27938)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import InitVar, field
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
|
||||
|
||||
import torch
|
||||
@@ -1217,6 +1218,19 @@ class ModelConfig:
|
||||
)
|
||||
return False
|
||||
|
||||
@cached_property
def is_mm_prefix_lm(self) -> bool:
    """Whether to use bidirectional attention for mm positions."""
    # Model types whose multimodal token positions use bidirectional
    # (PrefixLM-style) attention rather than causal attention.
    MM_PREFIX_LM_MODELS = (
        "gemma3",
        # TODO(Isotr0py): Disable paligemma for now before
        # we supports soft cap attention for FlexAttention
        # "paligemma",
    )
    # A config lacking ``model_type`` can never match, so default to a
    # sentinel that is not in the tuple instead of branching on hasattr.
    model_type = getattr(self.hf_config, "model_type", None)
    return model_type in MM_PREFIX_LM_MODELS
|
||||
|
||||
def get_head_size(self) -> int:
|
||||
# TODO remove hard code
|
||||
if self.is_deepseek_mla:
|
||||
|
||||
Reference in New Issue
Block a user