Signed-off-by: Ming Yang <minos.future@gmail.com>
This commit is contained in:
@@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
|
||||
_fp8_quantize,
|
||||
_resize_cache,
|
||||
extract_required_args)
|
||||
from vllm.scalar_type import scalar_types
|
||||
@@ -34,10 +35,6 @@ def run_cutlass_moe_fp8(
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
@@ -156,11 +153,27 @@ def run_cutlass_moe_fp8(
|
||||
problem_sizes1, problem_sizes2, a_map,
|
||||
c_map, global_num_experts, N, K)
|
||||
|
||||
a1q = ops.shuffle_rows(a1q, a_map)
|
||||
a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
|
||||
if per_act_token else a1q_scale)
|
||||
a1q = _fp8_perm(a1q, a_map)
|
||||
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
|
||||
expert_offsets = expert_offsets[:-1]
|
||||
|
||||
ab_strides1 = torch.full((w1.size(0), ),
|
||||
K,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((w1.size(0), ),
|
||||
2 * N,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((w1.size(0), ),
|
||||
N,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((w1.size(0), ),
|
||||
K,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
|
||||
if use_batched_format:
|
||||
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
|
||||
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
|
||||
@@ -197,8 +210,7 @@ def run_cutlass_moe_fp8(
|
||||
else:
|
||||
# We can't do this inplace because output may point to the same tensor
|
||||
# as c3.
|
||||
output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
|
||||
non_blocking=True)
|
||||
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
|
||||
|
||||
|
||||
# TODO (bnell): split class batched vs. non-batched?
|
||||
@@ -211,10 +223,6 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
num_dispatchers: Optional[int] = None,
|
||||
use_batched_format: bool = False,
|
||||
@@ -231,10 +239,6 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.num_dispatchers = num_dispatchers
|
||||
self.out_dtype = out_dtype
|
||||
self.ab_strides1 = ab_strides1
|
||||
self.ab_strides2 = ab_strides2
|
||||
self.c_strides1 = c_strides1
|
||||
self.c_strides2 = c_strides2
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
@property
|
||||
@@ -314,8 +318,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
|
||||
self.c_strides2, workspace13, workspace2, expert_num_tokens,
|
||||
a2_scale, workspace13, workspace2, expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
self.per_act_token_quant, self.per_out_ch_quant,
|
||||
self.use_batched_format)
|
||||
@@ -329,10 +332,6 @@ def cutlass_moe_fp8(
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
per_act_token: Optional[bool] = None,
|
||||
activation: str = "silu",
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
@@ -360,17 +359,6 @@ def cutlass_moe_fp8(
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides1 (torch.Tensor): The output strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides2 (torch.Tensor): The output strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- per_act_token (Optional[bool]): Whether the scale is per-token or
|
||||
per-tensor.
|
||||
- activation (str): The activation function to use.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
@@ -403,10 +391,6 @@ def cutlass_moe_fp8(
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
ab_strides1=ab_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides1=c_strides1,
|
||||
c_strides2=c_strides2,
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user