[Misc] DP : Add ExpertTokensMetadata (#20332)

Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
commit 805d62ca88
parent b7d9e9416f
Author: Varun Sundar Rabindranath
Date: 2025-07-09 20:33:14 -04:00
Committed by: GitHub
12 changed files with 117 additions and 79 deletions


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from math import prod
 from typing import Optional, final
@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum):
     BatchedExperts = "batched_experts",
 
 
+@dataclass
+class ExpertTokensMetadata:
+    """
+    Metadata regarding expert-token routing.
+    """
+    expert_num_tokens: torch.Tensor
+    expert_num_tokens_cpu: Optional[torch.Tensor]
+
+    @staticmethod
+    def make_from_list(expert_num_tokens_list: list[int],
+                       device: str) -> "ExpertTokensMetadata":
+        expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
+                                             device="cpu",
+                                             dtype=torch.int32)
+        return ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens_cpu.to(device,
+                                                       non_blocking=True),
+            expert_num_tokens_cpu=expert_num_tokens_cpu)
+
+
 # TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
@@ -114,8 +135,9 @@ class FusedMoEPrepareAndFinalize(ABC):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         """
         Perform any quantization (and/or) dispatching needed
         for this kernel.
@@ -134,7 +156,8 @@ class FusedMoEPrepareAndFinalize(ABC):
         Returns a tuple of:
         - quantized + dispatched a.
         - quantized + dispatched a1_scales.
-        - Optional tensor as big as number of local experts that contains the
+        - Optional ExpertTokensMetadata containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
           number of tokens assigned to each local expert.
         - Optional dispatched expert topk IDs
         - Optional dispatched expert topk weight
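The widened tuple now carries an Optional[ExpertTokensMetadata] in place of the bare per-expert count tensor. As a hedged sketch, a caller that needs the counts on the host might read them like this (the helper name is illustrative, not part of the commit):

    from typing import Optional

    def local_expert_counts(
            meta: Optional[ExpertTokensMetadata]) -> Optional[list[int]]:
        # Prefer the CPU mirror so no device -> host sync is needed.
        if meta is None:
            return None
        cpu = meta.expert_num_tokens_cpu
        if cpu is None:
            # Fall back to copying the device tensor (this synchronizes).
            cpu = meta.expert_num_tokens.cpu()
        return cpu.tolist()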
@@ -318,7 +341,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
     ):
         """
         This function computes the intermediate result of a Mixture of Experts
@@ -351,8 +374,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
           must be large enough to hold output of either MoE gemm.
         - workspace2 (torch.Tensor): A scratch tensor used for the activation
           function.
-        - expert_num_tokens: An optional tensor containing the number of tokens
-          assigned to each expert when using batched experts format input.
+        - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
+          ExpertTokensMetadata object containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
+          number of tokens assigned to each local expert.
         """
 
         raise NotImplementedError
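To illustrate what an apply() implementation could do with this argument, here is a sketch under assumed conventions: a batched activation layout of [num_local_experts, max_tokens_per_expert, hidden_dim] (which this commit does not prescribe) and a hypothetical helper name:

    from typing import Optional
    import torch

    def mask_padded_rows(
            hidden: torch.Tensor,
            expert_tokens_meta: Optional[ExpertTokensMetadata]) -> torch.Tensor:
        # Assumed layout: [num_local_experts, max_tokens_per_expert, hidden_dim].
        num_experts, max_tokens, _ = hidden.shape
        if expert_tokens_meta is None:
            return hidden  # no metadata: treat every slot as occupied
        counts = expert_tokens_meta.expert_num_tokens  # [num_experts] on device
        # Zero out rows beyond each expert's valid token count.
        slot = torch.arange(max_tokens, device=hidden.device)
        valid = slot.unsqueeze(0) < counts.unsqueeze(1)
        return hidden * valid.unsqueeze(-1).to(hidden.dtype)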
@@ -458,7 +483,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
-        (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
+        (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
              a1,
              a1_scale,
@@ -542,7 +567,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=a2_scale,
                 workspace13=workspace13,
                 workspace2=workspace2,
-                expert_num_tokens=expert_num_tokens,
+                expert_tokens_meta=expert_tokens_meta,
             )
         else:
             # The leading output dimension may not be equal to M, so
@@ -589,7 +614,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                     a2_scale=curr_a2_scale,
                     workspace13=workspace13,
                     workspace2=workspace2,
-                    expert_num_tokens=expert_num_tokens,
+                    expert_tokens_meta=expert_tokens_meta,
                 )
 
         self.prepare_finalize.finalize(output, fused_out, topk_weights,