[Bugfix][ROCm][GPT-OSS] Use old triton_kernels implementation on ROCm if the new API is not available (#34153)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
committed by GitHub
parent 5e75a14a66
commit c60f8e3b49
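In short, the module probes for the newer triton_kernels API at import time; if the probe fails on ROCm it flips a module-level flag, and every helper touched by the API change branches on that flag. A condensed sketch of the pattern the hunks below implement (only names that appear in the diff are used; the body of `legacy_routing` is elided):

from vllm.platforms import current_platform

use_legacy_triton_kernels = False

try:
    # Only the newer triton_kernels API exposes these symbols.
    from triton_kernels.tensor import SparseMatrix, make_ragged_tensor_metadata
except ImportError:
    if current_platform.is_rocm():
        # Older triton_kernels build on ROCm: fall back to the legacy code paths.
        use_legacy_triton_kernels = True
    else:
        raise


def legacy_routing(logits, n_expts_act, sm_first=False):
    if use_legacy_triton_kernels:
        # Deferred import so new-API installs never need triton_kernels.routing.
        from triton_kernels.routing import routing

        return routing(logits, n_expts_act, sm_first=sm_first)
    ...  # otherwise build routing data via the new SparseMatrix/topk path (hunks below)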
@@ -19,11 +19,14 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
 )
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels
 
 logger = init_logger(__name__)
 
+use_legacy_triton_kernels = False
+
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
@@ -38,10 +41,20 @@ if has_triton_kernels():
         from triton_kernels.tensor import (
             BIT,
             Bitmatrix,
-            SparseMatrix,
-            make_ragged_tensor_metadata,
-        )
-        from triton_kernels.topk import topk
+        )
+        from triton_kernels.topk import topk
+
+        try:
+            from triton_kernels.tensor import (
+                SparseMatrix,
+                make_ragged_tensor_metadata,
+            )
+        except ImportError:
+            if current_platform.is_rocm():
+                logger.warning_once("Using legacy triton_kernels on ROCm")
+                use_legacy_triton_kernels = True
+            else:
+                raise
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
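The probe keys off `SparseMatrix` and `make_ragged_tensor_metadata`, which only the newer `triton_kernels.tensor` exposes: if they import, the new code path is used everywhere; if not, ROCm falls back to the legacy path and every other platform still re-raises. An illustrative one-off check (not part of the commit) for which API a local triton_kernels install provides:

# Illustrative check, not part of the commit.
try:
    from triton_kernels.tensor import SparseMatrix, make_ragged_tensor_metadata  # noqa: F401

    print("new triton_kernels API available")
except ImportError:
    print("legacy triton_kernels API (vLLM falls back to triton_kernels.routing on ROCm)")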
@@ -101,6 +114,12 @@ def legacy_routing_from_bitmatrix(
     Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
     Creates routing data from a bitmatrix representation.
     """
+    if use_legacy_triton_kernels:
+        from triton_kernels.routing import routing_from_bitmatrix
+
+        return routing_from_bitmatrix(
+            bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
+        )
     sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
     dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
     combine_indx = sparse_logits.mask_metadata.col_sorted_indx
@@ -130,6 +149,10 @@ def legacy_routing(
     Replacement for the removed triton_kernels.routing.routing function.
     Computes routing data from gating logits.
     """
+    if use_legacy_triton_kernels:
+        from triton_kernels.routing import routing
+
+        return routing(logits, n_expts_act, sm_first=sm_first)
     if sm_first:
         logits = torch.softmax(logits, dim=-1)
     sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
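Both routing shims (this hunk and the previous one) short-circuit the same way: when the flag is set they defer to the old `triton_kernels.routing` entry points, and the import sits inside the branch so that importing the module never requires `triton_kernels.routing` on new-API installs. A minimal self-contained sketch of that idiom (the `route` name and the hard-coded flag are illustrative, not from the diff):

use_legacy_triton_kernels = True  # in vLLM this is set by the import probe shown earlier


def route(logits, n_expts_act, sm_first=False):
    if use_legacy_triton_kernels:
        # Deferred import: resolved only when the legacy path actually runs.
        from triton_kernels.routing import routing

        return routing(logits, n_expts_act, sm_first=sm_first)
    raise NotImplementedError("new-API path elided in this sketch")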
@@ -231,12 +254,23 @@ def triton_kernel_fused_experts(
     )
     output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
 
-    act = FusedActivation(
-        FnSpecs(
-            "swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
-        ),
-        (swiglu_alpha, swiglu_limit),
-    )
+    act = (
+        FusedActivation(
+            FnSpecs(
+                "swiglu",
+                triton_kernels.swiglu.swiglu_fn,
+                ("alpha", "limit"),
+                reduction_n=2,
+            ),
+            (swiglu_alpha, swiglu_limit),
+        )
+        if not use_legacy_triton_kernels
+        else FusedActivation(
+            FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+            (swiglu_alpha, swiglu_limit),
+            2,
+        )
+    )
     gammas = routing_data.gate_scal if routing_data else None
 
     matmul_ogs(
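The two branches reflect how each API describes the fused SwiGLU activation: the newer `FnSpecs` takes `reduction_n=2` as a keyword, while the legacy call passes the reduction factor as a trailing positional argument to `FusedActivation`. Condensed from the two branches above (same names as in the hunk):

# New triton_kernels API (non-legacy branch): reduction_n is a keyword on FnSpecs.
act_new = FusedActivation(
    FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
    (swiglu_alpha, swiglu_limit),
)

# Legacy triton_kernels API (ROCm fallback branch): the reduction factor is the
# trailing positional argument of FusedActivation instead.
act_legacy = FusedActivation(
    FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
    (swiglu_alpha, swiglu_limit),
    2,
)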
@@ -296,9 +330,18 @@ def make_routing_data(
 
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
-    bitmatrix = Bitmatrix(
-        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
-    )
+    bitmatrix = (
+        Bitmatrix(
+            bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
+        )
+        if not use_legacy_triton_kernels
+        else Bitmatrix(
+            bitmatrix,
+            shape=bitmatrix_shape,
+            shape_max=bitmatrix_shape_max,
+            scratchpad=None,
+        )
+    )
 
     # matmul_ogs expects invalid topk_weights to be -1s
     topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
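The `Bitmatrix` wrapper follows the same split: the newer constructor is given `dtype=BIT` explicitly, while the legacy fallback omits `dtype` and passes `scratchpad=None` instead. Condensed from the two branches above (same names as in the hunk):

# New triton_kernels API (non-legacy branch): dtype=BIT is passed explicitly.
bm_new = Bitmatrix(
    bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
)

# Legacy triton_kernels API (ROCm fallback branch): no dtype argument; the
# constructor takes scratchpad=None instead.
bm_legacy = Bitmatrix(
    bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None
)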