[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)
Signed-off-by: Shixian Cui <shixian@amazon.com>
This commit is contained in:
@@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.fused_experts_func = fused_experts
|
||||
|
||||
if self.use_cutlass:
|
||||
device = layer.w13_weight.device
|
||||
# ab_strides1 and c_strides2 are the same
|
||||
self.ab_strides1_c_strides2 = torch.full(
|
||||
(layer.local_num_experts, ),
|
||||
layer.hidden_size,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
self.ab_strides2 = torch.full(
|
||||
(layer.local_num_experts, ),
|
||||
layer.intermediate_size_per_partition,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
self.c_strides1 = torch.full(
|
||||
(layer.local_num_experts, ),
|
||||
2 * layer.intermediate_size_per_partition,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
@@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
)
|
||||
else:
|
||||
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
||||
@@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
)
|
||||
|
||||
self.disable_expert_map = (num_dispatchers > 1
|
||||
@@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user