[Release 2.10] Update to Torch 2.10 - final release (#30525)
@@ -974,7 +974,7 @@ def enable_batch_invariant_mode():
     )

     reduced_precision_val = (
-        (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
+        (False, False) if is_torch_equal_or_newer("2.10.0") else False
     )
     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
         reduced_precision_val
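Note: the gate above switches `allow_fp16_reduced_precision_reduction` to the tuple form only from the final 2.10.0 release onward, rather than from its dev builds. A minimal sketch of such a version gate, assuming only that `torch.__version__` is PEP 440-parseable (vLLM's real `is_torch_equal_or_newer` helper may differ in detail):

    import torch
    from packaging import version

    def is_torch_equal_or_newer(target: str) -> bool:
        # True when the installed torch is at least `target`, e.g. "2.10.0".
        # Note that "2.10.0.dev" parses as a pre-release, i.e. < "2.10.0".
        return version.parse(torch.__version__) >= version.parse(target)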
@@ -27,9 +27,21 @@ logger = init_logger(__name__)
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
-        from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
-        from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix
-        from triton_kernels.tensor import Bitmatrix
+        from triton_kernels.matmul_ogs import (
+            FnSpecs,
+            FusedActivation,
+            GatherIndx,
+            RoutingData,
+            ScatterIndx,
+            matmul_ogs,
+        )
+        from triton_kernels.tensor import (
+            BIT,
+            Bitmatrix,
+            SparseMatrix,
+            make_ragged_tensor_metadata,
+        )
+        from triton_kernels.topk import topk
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
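The widened import list tracks the reorganized triton_kernels API, in which the routing helpers left triton_kernels.routing. A hedged sketch for probing which API generation is installed, using only the symbols named in this diff:

    def has_legacy_triton_routing() -> bool:
        # Older triton_kernels releases exposed routing()/routing_from_bitmatrix()
        # here; newer ones removed them, which the except clause above anticipates.
        try:
            from triton_kernels.routing import routing  # noqa: F401
        except (AttributeError, ImportError):
            return False
        return True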
@@ -78,6 +90,58 @@ def pack_bitmatrix(
     tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)


+def legacy_routing_from_bitmatrix(
+    bitmatrix: "Bitmatrix",
+    expt_scal: torch.Tensor,
+    expt_indx: torch.Tensor,
+    n_expts_tot: int,
+    n_expts_act: int,
+) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
+    """
+    Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
+    Creates routing data from a bitmatrix representation.
+    """
+    sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
+    dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.col_sorted_indx
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum,
+        dispatch_indx.shape[0],
+    )
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+    routing_data = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.block_sizes,
+        n_expts_tot,
+        n_expts_act,
+        ragged_batch_metadata,
+    )
+    gather_idx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
+    return routing_data, gather_idx, scatter_idx
+
+
+def legacy_routing(
+    logits: torch.Tensor,
+    n_expts_act: int,
+    sm_first: bool = False,
+) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
+    """
+    Replacement for the removed triton_kernels.routing.routing function.
+    Computes routing data from gating logits.
+    """
+    if sm_first:
+        logits = torch.softmax(logits, dim=-1)
+    sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
+    return legacy_routing_from_bitmatrix(
+        sparse_logits.mask,
+        sparse_logits.vals,
+        sparse_logits.indx,
+        logits.shape[-1],
+        n_expts_act,
+    )
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1,  # Tensor or triton_kernels.Tensor
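The two legacy_* shims rebuild the removed routing helpers on top of the new topk/SparseMatrix primitives. A hypothetical usage sketch; the shapes, dtype, and device are illustrative only:

    import torch

    # 16 tokens routed over 8 experts, 2 active experts per token.
    logits = torch.randn(16, 8, device="cuda", dtype=torch.float16)
    routing_data, gather_idx, scatter_idx = legacy_routing(logits, n_expts_act=2)
    # gather_idx/scatter_idx drive matmul_ogs's token permutation, and
    # routing_data.gate_scal carries the per-token expert weights.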
@@ -91,7 +155,7 @@ def triton_kernel_moe_forward(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    routing_data, gather_idx, scatter_idx = routing(
+    routing_data, gather_idx, scatter_idx = legacy_routing(
         gating_output, topk, sm_first=not renormalize
     )

@@ -168,9 +232,10 @@ def triton_kernel_fused_experts(
         output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))

     act = FusedActivation(
-        FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+        FnSpecs(
+            "swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
+        ),
         (swiglu_alpha, swiglu_limit),
-        2,
     )
     gammas = routing_data.gate_scal if routing_data else None

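Here reduction_n=2 moves from a trailing positional argument of FusedActivation into FnSpecs; it records that swiglu consumes two input columns per output column. A plain-PyTorch reference for that shape contract, assuming the conventional gated formulation and a concatenated gate/linear layout (the actual kernel's interleaving, clamping, and bias handling may differ):

    import torch

    def swiglu_reference(
        x: torch.Tensor, alpha: float = 1.702, limit: float | None = None
    ) -> torch.Tensor:
        # Split the last dim in two: reduction_n=2 means output width = input width / 2.
        gate, up = x.chunk(2, dim=-1)
        if limit is not None:
            gate = gate.clamp(max=limit)
            up = up.clamp(min=-limit, max=limit)
        return up * gate * torch.sigmoid(alpha * gate)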
@@ -232,12 +297,12 @@ def make_routing_data(
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
     bitmatrix = Bitmatrix(
-        bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None
+        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
     )

     # matmul_ogs expects invalid topk_weights to be -1s
     topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
-    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
+    routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix(
         bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
     )
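Bitmatrix now takes an explicit dtype=BIT and no longer accepts a scratchpad argument. For orientation, a plain-torch sketch of the layout pack_bitmatrix produces, under the assumption (suggested by bm_cols * 32 above) that row r has bit e set iff expert e is selected for token r, packed 32 expert bits per 32-bit word; this models the data layout, not the Triton kernel:

    import torch

    def pack_bitmatrix_reference(topk_ids: torch.Tensor, n_experts: int) -> torch.Tensor:
        # topk_ids: [n_rows, num_topk]; invalid slots are marked with -1.
        n_rows = topk_ids.shape[0]
        bm_cols = (n_experts + 31) // 32  # 32 expert bits per int32 word
        bits = torch.zeros(n_rows, bm_cols, dtype=torch.int32, device=topk_ids.device)
        valid = topk_ids >= 0
        rows = torch.arange(n_rows, device=topk_ids.device)
        rows = rows.unsqueeze(-1).expand_as(topk_ids)[valid]
        experts = topk_ids[valid]
        shifts = (experts % 32).to(torch.int32)
        # Each selected expert sets a distinct bit, so summed contributions
        # equal a bitwise OR (expert 31 wraps into the sign bit, but the
        # bit pattern stays correct).
        bits.index_put_(
            (rows, experts // 32),
            torch.ones_like(shifts) << shifts,
            accumulate=True,
        )
        return bits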