[Build] Add OpenAI triton_kernels (#28788)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-11-18 19:45:20 -05:00
committed by GitHub
parent 49ef847aa8
commit 9912b8ccb8
6 changed files with 119 additions and 1 deletions

View File

@@ -8,6 +8,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
@@ -15,6 +16,7 @@ logger = init_logger(__name__)
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
assert has_triton_kernels()
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
from triton_kernels.numerics import InFlexData
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor

View File

@@ -18,6 +18,10 @@ from typing import Any
import regex as re
from typing_extensions import Never
from vllm.logger import init_logger
logger = init_logger(__name__)
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
@@ -62,6 +66,35 @@ def import_pynvml():
return pynvml
@cache
def import_triton_kernels():
"""
For convenience, prioritize triton_kernels that is available in
`site-packages`. Use `vllm.third_party.triton_kernels` as a fall-back.
"""
if _has_module("triton_kernels"):
import triton_kernels
logger.debug_once(
f"Loading module triton_kernels from {triton_kernels.__file__}.",
scope="local",
)
elif _has_module("vllm.third_party.triton_kernels"):
import vllm.third_party.triton_kernels as triton_kernels
logger.debug_once(
f"Loading module triton_kernels from {triton_kernels.__file__}.",
scope="local",
)
sys.modules["triton_kernels"] = triton_kernels
else:
logger.info_once(
"triton_kernels unavailable in this build. "
"Please consider installing triton_kernels from "
"https://github.com/triton-lang/triton/tree/main/python/triton_kernels"
)
def import_from_path(module_name: str, file_path: str | os.PathLike):
"""
Import a Python file according to its file path.
@@ -397,7 +430,12 @@ def has_deep_gemm() -> bool:
def has_triton_kernels() -> bool:
"""Whether the optional `triton_kernels` package is available."""
return _has_module("triton_kernels")
is_available = _has_module("triton_kernels") or _has_module(
"vllm.third_party.triton_kernels"
)
if is_available:
import_triton_kernels()
return is_available
def has_tilelang() -> bool: