Revert "[Bugfix] Fix fused MoE IMA (sans chunking) by using int64 for strides" (#34530)

This commit is contained in:
Michael Goin
2026-02-13 13:35:29 -05:00
committed by GitHub
parent 87789c8364
commit bfaa559305

View File

@@ -98,19 +98,19 @@ def fused_moe_kernel_gptq_awq(
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
# (A has M rows). # (A has M rows).
stride_am: tl.int64, stride_am,
stride_ak: tl.int64, stride_ak,
stride_be: tl.int64, stride_be,
stride_bk: tl.int64, stride_bk,
stride_bn: tl.int64, stride_bn,
stride_cm: tl.int64, stride_cm,
stride_cn: tl.int64, stride_cn,
stride_bse: tl.int64, stride_bse,
stride_bsk: tl.int64, stride_bsk,
stride_bsn: tl.int64, stride_bsn,
stride_bze: tl.int64, stride_bze,
stride_bzk: tl.int64, stride_bzk,
stride_bzn: tl.int64, stride_bzn,
block_k_diviable: tl.constexpr, block_k_diviable: tl.constexpr,
group_size: tl.constexpr, group_size: tl.constexpr,
# Meta-parameters # Meta-parameters
@@ -332,20 +332,20 @@ def fused_moe_kernel(
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
# (A has M rows). # (A has M rows).
stride_am: tl.int64, stride_am,
stride_ak: tl.int64, stride_ak,
stride_be: tl.int64, stride_be,
stride_bk: tl.int64, stride_bk,
stride_bn: tl.int64, stride_bn,
stride_cm: tl.int64, stride_cm,
stride_cn: tl.int64, stride_cn,
stride_asm: tl.int64, stride_asm,
stride_ask: tl.int64, stride_ask,
stride_bse: tl.int64, stride_bse,
stride_bsk: tl.int64, stride_bsk,
stride_bsn: tl.int64, stride_bsn,
stride_bbe: tl.int64, # bias expert stride stride_bbe, # bias expert stride
stride_bbn: tl.int64, # bias N stride stride_bbn, # bias N stride
# Block size for block-wise quantization # Block size for block-wise quantization
group_n: tl.constexpr, group_n: tl.constexpr,
group_k: tl.constexpr, group_k: tl.constexpr,