[Refactor] Remove align block size logic in moe_permute (#33449)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -44,10 +44,8 @@ def benchmark_permute(
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
# output_hidden_states = torch.empty_like(hidden_states)
|
||||
if use_fp8_w8a8:
|
||||
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||
else:
|
||||
align_block_size = None
|
||||
qhidden_states = hidden_states
|
||||
|
||||
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
||||
@@ -67,7 +65,6 @@ def benchmark_permute(
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
@@ -117,10 +114,8 @@ def benchmark_unpermute(
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
if use_fp8_w8a8:
|
||||
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||
else:
|
||||
align_block_size = None
|
||||
qhidden_states = hidden_states
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
@@ -142,7 +137,6 @@ def benchmark_unpermute(
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (
|
||||
|
||||
@@ -14,12 +14,10 @@ void moe_permute(
|
||||
const torch::Tensor& token_expert_indices, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>& expert_map, // [n_expert]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor& permuted_input, // [permuted_size, hidden]
|
||||
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
|
||||
torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
torch::Tensor& permuted_idx, // [permute_size]
|
||||
torch::Tensor& m_indices) { // [align_expand_m]
|
||||
torch::Tensor& permuted_idx) { // [permute_size]
|
||||
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
@@ -34,8 +32,6 @@ void moe_permute(
|
||||
"token_expert_indices shape must be same as inv_permuted_idx");
|
||||
auto n_token = input.sizes()[0];
|
||||
auto n_hidden = input.sizes()[1];
|
||||
auto align_block_size_value =
|
||||
align_block_size.has_value() ? align_block_size.value() : -1;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const long sorter_size =
|
||||
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
|
||||
@@ -73,42 +69,15 @@ void moe_permute(
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
|
||||
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
// DeepGEMM: use getMIndices kernel to compute
|
||||
// 1) align_expert_first_token_offset (aligned prefix offsets)
|
||||
// 2) m_indices (expert id for each aligned row)
|
||||
// eg. expert0: 3, expert1: 5, expert2: 2 tokens respectively
|
||||
// expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4
|
||||
// expert0: 3->4, expert1: 5->8, expert2: 2->4
|
||||
// align_expert_first_token_offset = [0, 4, 12, 16]
|
||||
// so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2]
|
||||
torch::Tensor align_expert_first_token_offset;
|
||||
const int64_t* aligned_expert_first_token_offset_ptr = nullptr;
|
||||
if (align_block_size.has_value()) {
|
||||
align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
aligned_expert_first_token_offset_ptr =
|
||||
get_ptr<int64_t>(align_expert_first_token_offset);
|
||||
}
|
||||
|
||||
// dispatch expandInputRowsKernelLauncher
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset),
|
||||
aligned_expert_first_token_offset_ptr, n_token, valid_num_ptr, n_hidden,
|
||||
topk, n_local_expert, align_block_size_value, stream);
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
||||
n_hidden, topk, n_local_expert, stream);
|
||||
});
|
||||
|
||||
// this is only required for DeepGemm and not required for CUTLASS group gemm
|
||||
if (align_block_size.has_value()) {
|
||||
expert_first_token_offset.copy_(align_expert_first_token_offset);
|
||||
}
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
@@ -201,16 +170,13 @@ void shuffle_rows(const torch::Tensor& input_tensor,
|
||||
|
||||
#else
|
||||
|
||||
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_ids,
|
||||
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids,
|
||||
const torch::Tensor& token_expert_indices,
|
||||
const std::optional<torch::Tensor>& expert_map,
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor& permuted_input,
|
||||
torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& src_row_id2dst_row_id_map,
|
||||
torch::Tensor& m_indices) {
|
||||
torch::Tensor& inv_permuted_idx, torch::Tensor& permuted_idx) {
|
||||
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
|
||||
@@ -168,64 +168,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
topk_id_ptr, size, expert_map_ptr, num_experts);
|
||||
}
|
||||
|
||||
template <bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset,
|
||||
int* m_indices, const int num_local_expert,
|
||||
const int align_block_size) {
|
||||
int eidx = blockIdx.x;
|
||||
int tidx = threadIdx.x;
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
|
||||
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
|
||||
}
|
||||
__syncthreads();
|
||||
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
|
||||
auto first_token_offset = smem_expert_first_token_offset[eidx];
|
||||
int n_token_in_expert = last_token_offset - first_token_offset;
|
||||
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
// round up to ALIGN_BLOCK_SIZE
|
||||
int64_t accumulate_align_offset = 0;
|
||||
for (int i = 1; i <= eidx + 1; i++) {
|
||||
int n_token = smem_expert_first_token_offset[i] -
|
||||
smem_expert_first_token_offset[i - 1];
|
||||
accumulate_align_offset =
|
||||
accumulate_align_offset + (n_token + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
if (i == eidx) {
|
||||
first_token_offset = accumulate_align_offset;
|
||||
}
|
||||
// last block store align_expert_first_token_offset
|
||||
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
|
||||
align_expert_first_token_offset[i] = accumulate_align_offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
|
||||
// update m_indice with expert id
|
||||
m_indices[first_token_offset + idx] = eidx;
|
||||
}
|
||||
}
|
||||
|
||||
void getMIndices(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset, int* m_indices,
|
||||
int num_local_expert, const int align_block_size,
|
||||
cudaStream_t stream) {
|
||||
int block = 256;
|
||||
int grid = num_local_expert;
|
||||
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
|
||||
if (align_block_size == -1) {
|
||||
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
|
||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||
num_local_expert, align_block_size);
|
||||
} else {
|
||||
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
|
||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||
num_local_expert, align_block_size);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -60,10 +60,9 @@ void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t const* expert_first_token_offset,
|
||||
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream);
|
||||
int num_local_experts, cudaStream_t stream);
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
@@ -76,9 +75,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
void getMIndices(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset, int* m_indices,
|
||||
int num_local_expert, const int align_block_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
#include "moe_permute_unpermute_kernel.inl"
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
|
||||
template <typename T, bool CHECK_SKIPPED>
|
||||
__global__ void expandInputRowsKernel(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t const* expert_first_token_offset,
|
||||
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
|
||||
int num_local_experts, int align_block_size) {
|
||||
int num_local_experts) {
|
||||
// Reverse permutation map.
|
||||
// I do this so that later, we can use the source -> dest map to do the k-way
|
||||
// reduction and unpermuting. I need the reverse map for that reduction to
|
||||
@@ -19,24 +18,6 @@ __global__ void expandInputRowsKernel(
|
||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||
int expert_id = sorted_experts[expanded_dest_row];
|
||||
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
// convert (unaligned) expanded_dest_row -> aligned expanded_dest_row.
|
||||
// aligned_expert_first_token_offset[e] provides the aligned prefix start
|
||||
// for expert e. For non-local experts we map to the end (total aligned M).
|
||||
int64_t aligned_base = 0;
|
||||
int64_t token_offset_in_expert = 0;
|
||||
if (expert_id >= num_local_experts) {
|
||||
aligned_base =
|
||||
__ldg(aligned_expert_first_token_offset + num_local_experts);
|
||||
token_offset_in_expert = 0;
|
||||
} else {
|
||||
aligned_base = __ldg(aligned_expert_first_token_offset + expert_id);
|
||||
token_offset_in_expert =
|
||||
expanded_dest_row - __ldg(expert_first_token_offset + expert_id);
|
||||
}
|
||||
expanded_dest_row = aligned_base + token_offset_in_expert;
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
assert(expanded_dest_row <= INT32_MAX);
|
||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||
@@ -76,29 +57,25 @@ void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t const* expert_first_token_offset,
|
||||
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
|
||||
int num_local_experts, cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows * k;
|
||||
int64_t const threads = 256;
|
||||
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
|
||||
FuncPtr func_map[2][2] = {
|
||||
{&expandInputRowsKernel<T, false, false>,
|
||||
&expandInputRowsKernel<T, false, true>},
|
||||
{&expandInputRowsKernel<T, true, false>,
|
||||
&expandInputRowsKernel<T, true, true>},
|
||||
using FuncPtr = decltype(&expandInputRowsKernel<T, true>);
|
||||
FuncPtr func_map[2] = {
|
||||
&expandInputRowsKernel<T, false>,
|
||||
&expandInputRowsKernel<T, true>,
|
||||
};
|
||||
bool is_check_skip = num_valid_tokens_ptr != nullptr;
|
||||
bool is_align_block_size = align_block_size != -1;
|
||||
auto func = func_map[is_check_skip][is_align_block_size];
|
||||
auto func = func_map[is_check_skip];
|
||||
|
||||
func<<<blocks, threads, 0, stream>>>(
|
||||
unpermuted_input, permuted_output, sorted_experts,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, permuted_idx,
|
||||
expert_first_token_offset, aligned_expert_first_token_offset, num_rows,
|
||||
num_valid_tokens_ptr, cols, k, num_local_experts, align_block_size);
|
||||
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
|
||||
num_local_experts);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
|
||||
@@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"moe_permute(Tensor input, Tensor topk_ids,"
|
||||
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
|
||||
"int n_local_expert,"
|
||||
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
|
||||
"int topk, Tensor! permuted_input, Tensor! "
|
||||
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
|
||||
"permuted_idx, Tensor! m_indices)->()");
|
||||
"permuted_idx)->()");
|
||||
|
||||
m.def(
|
||||
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
|
||||
|
||||
@@ -40,10 +40,8 @@ def torch_permute(
|
||||
n_local_expert: int,
|
||||
start_expert: int,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
align_block_size: int | None = None,
|
||||
fill_invalid_expert: int = -1,
|
||||
) -> list[torch.Tensor]:
|
||||
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
|
||||
n_token = hidden_states.shape[0]
|
||||
if expert_map is not None:
|
||||
is_local_expert = expert_map[topk_ids] != -1
|
||||
not_local_expert = expert_map[topk_ids] == -1
|
||||
@@ -70,107 +68,19 @@ def torch_permute(
|
||||
|
||||
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
|
||||
valid_row_idx = []
|
||||
if align_block_size is None:
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
|
||||
permuted_row_size = permuted_hidden_states.shape[0]
|
||||
m_indices = torch.empty(
|
||||
permuted_row_size, device="cuda", dtype=torch.int32
|
||||
).fill_(fill_invalid_expert)
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
m_indices[first_token_offset:last_token_offset] = i - 1
|
||||
src_row_id2dst_row_id_map = torch.arange(
|
||||
0, n_token * topk, device="cuda", dtype=torch.int32
|
||||
)[src2dst_idx].reshape((n_token, topk))
|
||||
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
|
||||
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
|
||||
return [
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map,
|
||||
dst_row_id2src_row_id_map,
|
||||
m_indices,
|
||||
valid_row_idx,
|
||||
]
|
||||
else:
|
||||
permuted_row_size = (
|
||||
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
permuted_idx = torch.full(
|
||||
(permuted_row_size,),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
permuted_hidden_states = torch.empty(
|
||||
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
|
||||
)
|
||||
align_src_row_id2dst_row_id = torch.empty(
|
||||
n_token * topk, device="cuda", dtype=torch.int32
|
||||
)
|
||||
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
|
||||
m_indices = torch.empty(
|
||||
permuted_row_size, device="cuda", dtype=torch.int32
|
||||
).fill_(fill_invalid_expert)
|
||||
# get align_permuted_hidden_states,
|
||||
# valid row_idx and align_expert_first_token_offset
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
n_token_in_expert = last_token_offset - first_token_offset
|
||||
align_expert_first_token_offset[i] = (
|
||||
align_expert_first_token_offset[i - 1]
|
||||
+ (n_token_in_expert + align_block_size - 1)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
align_first_token_offset = align_expert_first_token_offset[i - 1]
|
||||
align_last_token_offset = align_expert_first_token_offset[i]
|
||||
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
|
||||
first_token_offset : first_token_offset + n_token_in_expert
|
||||
]
|
||||
# store token in current expert with align_first_token_offset
|
||||
permuted_hidden_states[
|
||||
align_first_token_offset : align_first_token_offset + n_token_in_expert,
|
||||
...,
|
||||
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
|
||||
permuted_idx[
|
||||
align_first_token_offset : align_first_token_offset + n_token_in_expert
|
||||
] = dst_row_id2src_row_id_in_expert
|
||||
# set current expert m_indices
|
||||
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
|
||||
valid_row_idx += [
|
||||
i
|
||||
for i in range(
|
||||
align_first_token_offset,
|
||||
align_first_token_offset + n_token_in_expert,
|
||||
)
|
||||
]
|
||||
# get align_src_row_id2dst_row_id
|
||||
for i in range(n_token * topk):
|
||||
eid = sorted_topk_ids[i]
|
||||
if eid >= n_local_expert:
|
||||
# check token not in local expert
|
||||
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
|
||||
continue
|
||||
first_token_offset = expert_first_token_offset[eid]
|
||||
align_first_token_offset = align_expert_first_token_offset[eid]
|
||||
token_offset = i - first_token_offset
|
||||
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
|
||||
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
|
||||
(n_token, topk)
|
||||
)
|
||||
return [
|
||||
permuted_hidden_states,
|
||||
align_expert_first_token_offset,
|
||||
align_src_row_id2dst_row_id,
|
||||
permuted_idx,
|
||||
m_indices,
|
||||
valid_row_idx,
|
||||
]
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
|
||||
src_row_id2dst_row_id_map = torch.arange(
|
||||
0, n_token * topk, device="cuda", dtype=torch.int32
|
||||
)[src2dst_idx].reshape((n_token, topk))
|
||||
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
|
||||
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
|
||||
return [
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map,
|
||||
dst_row_id2src_row_id_map,
|
||||
valid_row_idx,
|
||||
]
|
||||
|
||||
|
||||
def torch_unpermute(
|
||||
@@ -207,7 +117,6 @@ def torch_unpermute(
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("align_block_size", [None, 128])
|
||||
def test_moe_permute_unpermute(
|
||||
n_token: int,
|
||||
n_hidden: int,
|
||||
@@ -215,11 +124,9 @@ def test_moe_permute_unpermute(
|
||||
n_expert: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
align_block_size: int | None,
|
||||
):
|
||||
if not moe_permute_unpermute_supported():
|
||||
pytest.skip("moe_permute_unpermute is not supported on this platform.")
|
||||
fill_invalid_expert = 0
|
||||
ep_rank = np.random.randint(0, ep_size)
|
||||
expert_map = None
|
||||
n_local_expert = n_expert
|
||||
@@ -238,7 +145,6 @@ def test_moe_permute_unpermute(
|
||||
gold_expert_first_token_offset,
|
||||
gold_inv_permuted_idx,
|
||||
gold_permuted_idx,
|
||||
gold_m_indices,
|
||||
valid_row_idx,
|
||||
) = torch_permute(
|
||||
hidden_states,
|
||||
@@ -249,8 +155,6 @@ def test_moe_permute_unpermute(
|
||||
n_local_expert,
|
||||
start_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert,
|
||||
)
|
||||
|
||||
(
|
||||
@@ -258,7 +162,7 @@ def test_moe_permute_unpermute(
|
||||
_,
|
||||
expert_first_token_offset,
|
||||
inv_permuted_idx,
|
||||
m_indices,
|
||||
_,
|
||||
) = moe_permute(
|
||||
hidden_states=hidden_states,
|
||||
a1q_scale=None,
|
||||
@@ -266,8 +170,6 @@ def test_moe_permute_unpermute(
|
||||
n_expert=n_expert,
|
||||
n_local_expert=n_local_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert,
|
||||
)
|
||||
|
||||
# check expert_first_token_offset
|
||||
@@ -278,11 +180,6 @@ def test_moe_permute_unpermute(
|
||||
torch.testing.assert_close(
|
||||
gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
|
||||
)
|
||||
# check mindice
|
||||
# current kernel usage assumes deepgemm requires align_block_size
|
||||
# when it's not provided then we don't compute m_indices (for cutlass)
|
||||
if align_block_size is not None:
|
||||
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
|
||||
|
||||
# check permuted_hidden_states, only valid token
|
||||
torch.testing.assert_close(
|
||||
|
||||
@@ -11,8 +11,6 @@ def moe_permute(
|
||||
n_expert: int,
|
||||
n_local_expert: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
align_block_size: int | None = None,
|
||||
fill_invalid_expert: int = -1,
|
||||
permuted_hidden_states: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -27,9 +25,6 @@ def moe_permute(
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- align_block_size (Optional[int]): align group gemm block size for deepgemm
|
||||
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
|
||||
to workaround DeepGemm unsupported -1 in m_indices
|
||||
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
|
||||
If None, the output tensor will be created in this function.
|
||||
Returns:
|
||||
@@ -37,12 +32,9 @@ def moe_permute(
|
||||
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
|
||||
if original scale not per-tensor scaling
|
||||
- expert_first_token_offset (torch.Tensor): offset of the first token
|
||||
of each expert for standard grouped gemm. if enable 'align_block_size'
|
||||
expert_first_token_offset will align up to 'align_block_size'.
|
||||
of each expert for standard grouped gemm.
|
||||
- inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
|
||||
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
|
||||
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
|
||||
the group which the j-th row of the LHS belong to.`
|
||||
"""
|
||||
n_token, n_hidden = hidden_states.size()
|
||||
topk = topk_ids.size(1)
|
||||
@@ -50,17 +42,6 @@ def moe_permute(
|
||||
"permue kernel need hidden dim align to 16B"
|
||||
)
|
||||
permuted_row_size = n_token * topk
|
||||
if align_block_size is not None:
|
||||
permuted_row_size = (
|
||||
(
|
||||
permuted_row_size
|
||||
+ n_expert * (align_block_size - 1)
|
||||
+ align_block_size
|
||||
- 1
|
||||
)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
if n_local_expert == -1:
|
||||
n_local_expert = n_expert
|
||||
if permuted_hidden_states is None:
|
||||
@@ -78,12 +59,6 @@ def moe_permute(
|
||||
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
|
||||
).reshape((n_token, topk))
|
||||
|
||||
m_indices = torch.full(
|
||||
(permuted_row_size,),
|
||||
fill_invalid_expert,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
expert_first_token_offset = torch.empty(
|
||||
n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
|
||||
)
|
||||
@@ -105,12 +80,10 @@ def moe_permute(
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
topk,
|
||||
align_block_size,
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
inv_permuted_idx,
|
||||
permuted_idx,
|
||||
m_indices,
|
||||
)
|
||||
|
||||
if a1q_scale is not None and a1q_scale.dim() > 1:
|
||||
@@ -120,7 +93,7 @@ def moe_permute(
|
||||
a1q_scale,
|
||||
expert_first_token_offset,
|
||||
inv_permuted_idx.flatten(),
|
||||
m_indices,
|
||||
permuted_idx,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user