[Performance][B200] Fix deepgemm prologue (#27897)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
committed by GitHub
parent 478ee511de
commit 74a9a9faad
@@ -60,11 +60,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     create_fp8_input_scale,
     create_fp8_scale_parameter,
     create_fp8_weight_parameter,
-    expert_weight_is_col_major,
+    deepgemm_post_process_fp8_weight_block,
     maybe_post_process_fp8_weight_block,
     process_fp8_weight_block_strategy,
     process_fp8_weight_tensor_strategy,
-    requant_weight_ue8m0_inplace,
     validate_fp8_block_shape,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -94,7 +93,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
-    get_col_major_tma_aligned_tensor,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
 )
@@ -846,15 +844,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             # DeepGemm scales need to be transposed and aligned. We try to do
             # it ahead of time for performance reasons.
-            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
-                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-                    layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                        layer.w13_weight_scale_inv
-                    )
-                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-                    layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                        layer.w2_weight_scale_inv
-                    )
+            if self.allow_deep_gemm:
+                dg_w13_weight, dg_w13_weight_scale_inv = (
+                    deepgemm_post_process_fp8_weight_block(
+                        wq=layer.w13_weight.data,
+                        ws=layer.w13_weight_scale_inv.data,
+                        quant_block_shape=tuple(layer.weight_block_size),
+                        use_e8m0=is_deep_gemm_e8m0_used(),
+                    )
+                )
+                dg_w2_weight, dg_w2_weight_scale_inv = (
+                    deepgemm_post_process_fp8_weight_block(
+                        wq=layer.w2_weight.data,
+                        ws=layer.w2_weight_scale_inv.data,
+                        quant_block_shape=tuple(layer.weight_block_size),
+                        use_e8m0=is_deep_gemm_e8m0_used(),
+                    )
+                )
+                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = Parameter(
+                    dg_w13_weight_scale_inv, requires_grad=False
+                )
+                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = Parameter(
+                    dg_w2_weight_scale_inv, requires_grad=False
+                )
 
         # If checkpoint is fp16, quantize in place.
         elif not self.quant_config.is_checkpoint_fp8_serialized:
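For orientation between the hunk above and the one below: work that previously happened in two places (transposing the block scales into the column-major, TMA-aligned layout DeepGEMM expects, plus requantizing scales to UE8M0 on e8m0 hardware) is now folded into a single call at weight-load time, matching the comment above about doing it "ahead of time for performance reasons". A minimal sketch of the scale-side idea in plain PyTorch; post_process_block_scales is an illustrative stand-in, not vllm's actual deepgemm_post_process_fp8_weight_block (which also returns the matching quantized weight tensor):

import torch

def post_process_block_scales(ws: torch.Tensor, use_e8m0: bool) -> torch.Tensor:
    # Sketch only: the real helper pairs this with the quantized weights.
    if use_e8m0:
        # UE8M0 scales are exponent-only (powers of two); rounding each
        # block scale up keeps the rescaled weights from overflowing.
        ws = torch.exp2(torch.ceil(torch.log2(ws)))
    # Materialize column-major strides once at load time, so the kernel
    # does not pay for a scale transpose on every forward pass.
    return ws.transpose(-2, -1).contiguous().transpose(-2, -1)

Doing this once in process_weights_after_loading also lets the per-forward epilogue below be deleted outright.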
@@ -990,31 +1004,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
-        if is_deep_gemm_e8m0_used() and self.block_quant:
-            assert layer.weight_block_size is not None
-            # Re-quantise the expert weights so their scales are UE8M0.
-            block_sz = tuple(layer.weight_block_size)
-            requant_weight_ue8m0_inplace(
-                layer.w13_weight.data,
-                layer.w13_weight_scale_inv.data,
-                block_sz,
-            )
-            requant_weight_ue8m0_inplace(
-                layer.w2_weight.data,
-                layer.w2_weight_scale_inv.data,
-                block_sz,
-            )
-
-        # Ensure column-major TMA alignment expected by DeepGEMM.
-        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w13_weight_scale_inv
-            )
-        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w2_weight_scale_inv
-            )
-
     def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
             self.rocm_aiter_moe_enabled
@@ -1037,7 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
     ) -> FusedMoEPermuteExpertsUnpermute:
         from vllm.model_executor.layers.fused_moe import (
-            BatchedTritonOrDeepGemmExperts,
+            BatchedDeepGemmExperts,
+            BatchedTritonExperts,
             TritonOrDeepGemmExperts,
         )
 
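The hunk above swaps the BatchedTritonOrDeepGemmExperts wrapper import for its two concrete implementations; the hunk below then picks the class once, up front, instead of letting the wrapper dispatch internally, and the debug line logs under the chosen class's own name. With illustrative values (max tokens and block size are assumptions, not from this diff), the new format string renders as:

    BatchedDeepGemmExperts(Fp8MoEMethod): max_tokens_per_rank=64, block_size=[128, 128], per_act_token=False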
@@ -1053,20 +1043,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         ):
             max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
             assert max_num_tokens_per_rank is not None
+
+            experts_impl = (
+                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
+            )
             logger.debug(
-                "BatchedTritonOrDeepGemmExperts(%s): "
-                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
+                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
+                experts_impl.__name__,
                 self.__class__.__name__,
                 max_num_tokens_per_rank,
                 self.weight_block_size,
                 False,
             )
-            return BatchedTritonOrDeepGemmExperts(
+            return experts_impl(
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
-                allow_deep_gemm=self.allow_deep_gemm,
             )
+
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
             experts = select_cutlass_fp8_gemm_impl(
                 self.moe,