Add flagos in MiniCPM-o (#34126)
Signed-off-by: tc-mb <caitianchi@modelbest.cn> Signed-off-by: Vincent-Xiao <vincent.xiao.me@gmail.com> Co-authored-by: Vincent-Xiao <vincent.xiao.me@gmail.com>
This commit is contained in:
@@ -24,6 +24,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
|
||||
@@ -75,6 +76,47 @@ from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
if os.getenv("USE_FLAGOS") == "1":
|
||||
import flag_gems
|
||||
|
||||
FLAG_GEMS_CONFIG = [
|
||||
"sort",
|
||||
"sort_stable",
|
||||
"layer_norm",
|
||||
"clamp_",
|
||||
"cos",
|
||||
"embedding",
|
||||
"exp",
|
||||
"exponential_",
|
||||
"full",
|
||||
"gather",
|
||||
"gelu",
|
||||
"index",
|
||||
"le",
|
||||
"lt",
|
||||
"lt_scalar",
|
||||
"masked_fill_",
|
||||
"max",
|
||||
"ones",
|
||||
"pow_scalar",
|
||||
"prod_dim",
|
||||
"rand_like",
|
||||
"reciprocal",
|
||||
"repeat",
|
||||
"scatter",
|
||||
"scatter_",
|
||||
"sin",
|
||||
"sub",
|
||||
"true_divide",
|
||||
"true_divide_",
|
||||
"uniform_",
|
||||
"where_scalar_self",
|
||||
"where_self_out",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
]
|
||||
flag_gems.only_enable(record=False, include=FLAG_GEMS_CONFIG)
|
||||
|
||||
|
||||
class MiniCPMOAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user