[Kernel] Apply torch.Tag.needs_fixed_stride_order only for torch==2.6.0 (#19346)
Signed-off-by: rzou <zou3519@gmail.com>
This commit is contained in:
@@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
// vLLM custom ops
|
// vLLM custom ops
|
||||||
//
|
//
|
||||||
|
|
||||||
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
|
// The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
|
||||||
|
// so we need
|
||||||
// to override this for many GEMMs with the following tag. Otherwise,
|
// to override this for many GEMMs with the following tag. Otherwise,
|
||||||
// torch.compile will force all input tensors to be contiguous(), which
|
// torch.compile will force all input tensors to be contiguous(), which
|
||||||
// will break many custom ops that require column-major weight matrices.
|
// will break many custom ops that require column-major weight matrices.
|
||||||
// TODO: remove this for PyTorch 2.8, when the default is planned to switch
|
// This was a bug and PyTorch 2.7 has since fixed this.
|
||||||
// to match exact eager-mode strides.
|
#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
|
||||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
#define stride_tag at::Tag::needs_fixed_stride_order
|
||||||
|
#else
|
||||||
|
#define stride_tag
|
||||||
|
#endif
|
||||||
|
|
||||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||||
|
|
||||||
|
|
||||||
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
|
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
|
||||||
@@ -93,8 +93,12 @@ def mla_decode_fwd_fake(
|
|||||||
|
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
|
if is_torch_equal_or_newer("2.7.0"):
|
||||||
|
tags = ()
|
||||||
|
else:
|
||||||
|
tags = (torch.Tag.needs_fixed_stride_order, ),
|
||||||
direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
|
direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
|
||||||
op_func=mla_decode_fwd_impl,
|
op_func=mla_decode_fwd_impl,
|
||||||
mutates_args=["o"],
|
mutates_args=["o"],
|
||||||
fake_impl=mla_decode_fwd_fake,
|
fake_impl=mla_decode_fwd_fake,
|
||||||
tags=[torch.Tag.needs_fixed_stride_order])
|
tags=tags)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
|||||||
dequant_mxfp4)
|
dequant_mxfp4)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||||
|
|
||||||
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
||||||
@@ -1056,7 +1056,8 @@ direct_register_custom_op(
|
|||||||
op_func=inplace_fused_experts,
|
op_func=inplace_fused_experts,
|
||||||
mutates_args=["hidden_states"],
|
mutates_args=["hidden_states"],
|
||||||
fake_impl=inplace_fused_experts_fake,
|
fake_impl=inplace_fused_experts_fake,
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
tags=(() if is_torch_equal_or_newer("2.7.0") else
|
||||||
|
(torch.Tag.needs_fixed_stride_order, )),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1122,7 +1123,8 @@ direct_register_custom_op(
|
|||||||
op_func=outplace_fused_experts,
|
op_func=outplace_fused_experts,
|
||||||
mutates_args=[],
|
mutates_args=[],
|
||||||
fake_impl=outplace_fused_experts_fake,
|
fake_impl=outplace_fused_experts_fake,
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
tags=(() if is_torch_equal_or_newer("2.7.0") else
|
||||||
|
(torch.Tag.needs_fixed_stride_order, )),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user