[Misc] DP : Add ExpertTokensMetadata (#20332)
Signed-off-by: Varun <vsundarr@redhat.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
b7d9e9416f
commit
805d62ca88
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Optional, final
|
||||
@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum):
|
||||
BatchedExperts = "batched_experts",
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertTokensMetadata:
|
||||
"""
|
||||
Metadata regarding expert-token routing.
|
||||
"""
|
||||
expert_num_tokens: torch.Tensor
|
||||
expert_num_tokens_cpu: Optional[torch.Tensor]
|
||||
|
||||
@staticmethod
|
||||
def make_from_list(expert_num_tokens_list: list[int],
|
||||
device: str) -> "ExpertTokensMetadata":
|
||||
expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens_cpu.to(device,
|
||||
non_blocking=True),
|
||||
expert_num_tokens_cpu=expert_num_tokens_cpu)
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
@@ -114,8 +135,9 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed
|
||||
for this kernel.
|
||||
@@ -134,7 +156,8 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
- quantized + dispatched a1_scales.
|
||||
- Optional tensor as big as number of local experts that contains the
|
||||
- Optional ExpertTokensMetadata containing gpu/cpu tensors
|
||||
as big as the number of local experts with the information about the
|
||||
number of tokens assigned to each local expert.
|
||||
- Optional dispatched expert topk IDs
|
||||
- Optional dispatched expert topk weight
|
||||
@@ -318,7 +341,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
):
|
||||
"""
|
||||
This function computes the intermediate result of a Mixture of Experts
|
||||
@@ -351,8 +374,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
must be large enough to hold output of either MoE gemm.
|
||||
- workspace2 (torch.Tensor): A scratch tensor used for the activation
|
||||
function.
|
||||
- expert_num_tokens: An optional tensor containing the number of tokens
|
||||
assigned to each expert when using batched experts format input.
|
||||
- expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
|
||||
ExpertTokensMetadata object containing gpu/cpu tensors
|
||||
as big as the number of local experts with the information about the
|
||||
number of tokens assigned to each local expert.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -458,7 +483,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
a1_scale,
|
||||
@@ -542,7 +567,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
)
|
||||
else:
|
||||
# The leading output dimension may not be equal to M, so
|
||||
@@ -589,7 +614,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a2_scale=curr_a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
)
|
||||
|
||||
self.prepare_finalize.finalize(output, fused_out, topk_weights,
|
||||
|
||||
Reference in New Issue
Block a user