[Bug] [ROCm] Fix Llama 4 Enablement Bug on ROCm: V0 ROCmFlashAttentionImpl and Triton Fused MoE bugs (#16198)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Co-authored-by: Hongxia Yang <hongxia.yang@amd.com> Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>
This commit is contained in:
@@ -40,7 +40,7 @@ from dataclasses import dataclass, field
|
||||
from functools import cache, lru_cache, partial, wraps
|
||||
from types import MappingProxyType
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
||||
Optional, Type, TypeVar, Union, cast, overload)
|
||||
Optional, Tuple, Type, TypeVar, Union, cast, overload)
|
||||
from uuid import uuid4
|
||||
|
||||
import cachetools
|
||||
@@ -1935,12 +1935,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def direct_register_custom_op(
|
||||
op_name: str,
|
||||
op_func: Callable,
|
||||
mutates_args: list[str],
|
||||
fake_impl: Optional[Callable] = None,
|
||||
target_lib: Optional[Library] = None,
|
||||
dispatch_key: str = "CUDA",
|
||||
op_name: str,
|
||||
op_func: Callable,
|
||||
mutates_args: list[str],
|
||||
fake_impl: Optional[Callable] = None,
|
||||
target_lib: Optional[Library] = None,
|
||||
dispatch_key: str = "CUDA",
|
||||
tags: Tuple[torch.Tag, ...] = (),
|
||||
):
|
||||
"""
|
||||
`torch.library.custom_op` can have significant overhead because it
|
||||
@@ -1979,7 +1980,7 @@ def direct_register_custom_op(
|
||||
import torch._custom_op.impl
|
||||
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
||||
my_lib = target_lib or vllm_lib
|
||||
my_lib.define(op_name + schema_str)
|
||||
my_lib.define(op_name + schema_str, tags=tags)
|
||||
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
|
||||
if fake_impl is not None:
|
||||
my_lib._register_fake(op_name, fake_impl)
|
||||
|
||||
Reference in New Issue
Block a user