Add support for ModelOpt MXFP8 MoE models (#35986)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
This commit is contained in:
@@ -1204,17 +1204,26 @@ class FusedMoE(CustomOp):
|
||||
# Determine per-tensor weight scale patterns based on variant
|
||||
# Use the dedicated method instead of brittle string matching
|
||||
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern()
|
||||
quant_method = getattr(param, "quant_method", None)
|
||||
|
||||
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
|
||||
# weights scales.
|
||||
# Input scales are always per-tensor.
|
||||
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
|
||||
# "weight_scale" for per-tensor scales.
|
||||
# NOTE: ModelOpt MXFP8 MoE uses block scales in weight_scale
|
||||
# tensors (quant_method=BLOCK), so those must not be treated
|
||||
# as per-tensor scalars here.
|
||||
is_block_weight_scale = (
|
||||
"weight_scale" in weight_name
|
||||
and quant_method == FusedMoeWeightScaleSupported.BLOCK.value
|
||||
)
|
||||
is_per_tensor = (
|
||||
"weight_scale_2" in weight_name
|
||||
if uses_weight_scale_2
|
||||
else "weight_scale" in weight_name
|
||||
) or "input_scale" in weight_name
|
||||
is_per_tensor = is_per_tensor and not is_block_weight_scale
|
||||
if is_per_tensor:
|
||||
self._load_per_tensor_weight_scale(
|
||||
shard_id=shard_id,
|
||||
|
||||
44
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
Normal file
44
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MxFp8MoeBackend(Enum):
|
||||
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
|
||||
|
||||
|
||||
def select_mxfp8_moe_backend(
|
||||
config: FusedMoEConfig,
|
||||
) -> MxFp8MoeBackend:
|
||||
if config.is_lora_enabled:
|
||||
raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")
|
||||
|
||||
AVAILABLE_BACKENDS = [
|
||||
MxFp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
]
|
||||
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
mapping = {
|
||||
"flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
}
|
||||
if backend := mapping.get(runner_backend):
|
||||
logger.info_once(
|
||||
"Using '%s' MxFp8 MoE backend (user-requested).",
|
||||
backend.value,
|
||||
)
|
||||
return backend
|
||||
raise ValueError(
|
||||
f"moe_backend='{runner_backend}' is not supported for MXFP8 MoE. "
|
||||
f"Expected one of {list(mapping.keys())}."
|
||||
)
|
||||
|
||||
# Auto-select: only one backend available for now.
|
||||
backend = AVAILABLE_BACKENDS[0]
|
||||
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
|
||||
return backend
|
||||
@@ -9,17 +9,19 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
@@ -28,6 +30,10 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
|
||||
MxFp8MoeBackend,
|
||||
select_mxfp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
convert_to_nvfp4_moe_kernel_format,
|
||||
is_global_sf_supported_for_nvfp4_backend,
|
||||
@@ -46,6 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
@@ -60,6 +69,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
MXFP8_VALUE_DTYPE,
|
||||
Mxfp8LinearBackend,
|
||||
Mxfp8LinearOp,
|
||||
mxfp8_e4m3_quantize,
|
||||
swizzle_mxfp8_scale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
@@ -86,7 +96,8 @@ from vllm.model_executor.parameter import (
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@@ -1487,17 +1498,6 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
|
||||
# MXFP8 hardware acceleration requires Blackwell (SM100) or newer
|
||||
return 100
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
# MXFP8 does not yet support MoE models
|
||||
if isinstance(layer, FusedMoE):
|
||||
raise NotImplementedError(
|
||||
"MXFP8 quantization does not yet support MoE models. "
|
||||
"Please use FP8 or NVFP4 quantization for MoE models."
|
||||
)
|
||||
return super().get_quant_method(layer, prefix)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
@@ -1699,8 +1699,351 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
|
||||
class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
|
||||
"""FlashInfer TRTLLM MXFP8 block-scale MoE for ModelOpt checkpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: ModelOptMxFp8Config,
|
||||
moe_config: FusedMoEConfig,
|
||||
) -> None:
|
||||
super().__init__(moe_config)
|
||||
self.quant_config = quant_config
|
||||
assert self.quant_config.is_checkpoint_mxfp8_serialized
|
||||
|
||||
# Select MXFP8 MoE backend
|
||||
self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
if hidden_size % MXFP8_BLOCK_SIZE != 0:
|
||||
raise ValueError(
|
||||
f"MXFP8 MoE requires hidden_size divisible by {MXFP8_BLOCK_SIZE}, "
|
||||
f"got {hidden_size}."
|
||||
)
|
||||
if intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0:
|
||||
raise ValueError(
|
||||
"MXFP8 MoE requires intermediate_size_per_partition divisible by "
|
||||
f"{MXFP8_BLOCK_SIZE}, got {intermediate_size_per_partition}."
|
||||
)
|
||||
|
||||
layer.num_experts = num_experts
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
|
||||
|
||||
# GEMM 1 weights: [E, (2I or I), H]
|
||||
w13_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=MXFP8_VALUE_DTYPE,
|
||||
),
|
||||
input_dim=2,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
|
||||
# GEMM 2 weights: [E, H, I]
|
||||
w2_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=MXFP8_VALUE_DTYPE,
|
||||
),
|
||||
input_dim=2,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
|
||||
# Per-block (K=32) E8M0 scales.
|
||||
w13_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
hidden_size // MXFP8_BLOCK_SIZE,
|
||||
dtype=MXFP8_SCALE_DTYPE,
|
||||
),
|
||||
input_dim=2,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
|
||||
dtype=MXFP8_SCALE_DTYPE,
|
||||
),
|
||||
input_dim=2,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# Ensure the generic MoE weight-loader treats these as block scales.
|
||||
set_weight_attrs(
|
||||
layer.w13_weight_scale,
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.w2_weight_scale,
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _check_weight_dtypes(layer: torch.nn.Module) -> None:
|
||||
"""Validate weight and scale dtypes before processing."""
|
||||
expected = {
|
||||
"w13_weight": MXFP8_VALUE_DTYPE,
|
||||
"w2_weight": MXFP8_VALUE_DTYPE,
|
||||
"w13_weight_scale": MXFP8_SCALE_DTYPE,
|
||||
"w2_weight_scale": MXFP8_SCALE_DTYPE,
|
||||
}
|
||||
for name, expected_dtype in expected.items():
|
||||
actual = getattr(layer, name).dtype
|
||||
if actual != expected_dtype:
|
||||
raise ValueError(
|
||||
f"Expected {name} dtype {expected_dtype}, got {actual}."
|
||||
)
|
||||
|
||||
def _shuffle_weights_for_trtllm(self, layer: torch.nn.Module) -> None:
|
||||
"""Shuffle weights and scales into FlashInfer TRTLLM MXFP8 layout."""
|
||||
from flashinfer import (
|
||||
reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a,
|
||||
)
|
||||
|
||||
epilogue_tile_m = 128
|
||||
num_experts = layer.w13_weight.shape[0]
|
||||
is_gated = self.moe.is_act_and_mul
|
||||
intermediate_size_factor = 2 if is_gated else 1
|
||||
|
||||
w13_weight = layer.w13_weight.data
|
||||
w13_scale = layer.w13_weight_scale.data
|
||||
if is_gated:
|
||||
# FI TRTLLM gated kernels use W31 ordering. Model checkpoints store
|
||||
# gated projection as W13, so convert once before shuffling.
|
||||
w13_weight = swap_w13_to_w31(w13_weight)
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
w13_weight_shuffled = []
|
||||
w2_weight_shuffled = []
|
||||
w13_scale_shuffled = []
|
||||
w2_scale_shuffled = []
|
||||
for i in range(num_experts):
|
||||
w13_i = w13_weight[i].reshape(
|
||||
intermediate_size_factor * layer.intermediate_size_per_partition, -1
|
||||
)
|
||||
w13_sf_i = w13_scale[i].reshape(
|
||||
intermediate_size_factor * layer.intermediate_size_per_partition, -1
|
||||
)
|
||||
if is_gated:
|
||||
# Reorder rows for gated activation layout expected by TRTLLM.
|
||||
w13_i = reorder_rows_for_gated_act_gemm(w13_i.clone())
|
||||
w13_sf_i = reorder_rows_for_gated_act_gemm(w13_sf_i.clone())
|
||||
|
||||
w13_shuffled_i = shuffle_matrix_a(w13_i.view(torch.uint8), epilogue_tile_m)
|
||||
w2_shuffled_i = shuffle_matrix_a(
|
||||
layer.w2_weight.data[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
w13_weight_shuffled.append(
|
||||
w13_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE)
|
||||
)
|
||||
w2_weight_shuffled.append(
|
||||
w2_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE)
|
||||
)
|
||||
w13_sf_shuffled_i = shuffle_matrix_sf_a(
|
||||
w13_sf_i.view(torch.uint8).reshape(
|
||||
intermediate_size_factor * layer.intermediate_size_per_partition,
|
||||
-1,
|
||||
),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
w2_sf_shuffled_i = shuffle_matrix_sf_a(
|
||||
layer.w2_weight_scale.data[i]
|
||||
.view(torch.uint8)
|
||||
.reshape(layer.hidden_size, -1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
w13_scale_shuffled.append(
|
||||
w13_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE)
|
||||
)
|
||||
w2_scale_shuffled.append(
|
||||
w2_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE)
|
||||
)
|
||||
|
||||
replace_parameter(
|
||||
layer, "w13_weight", torch.stack(w13_weight_shuffled).contiguous()
|
||||
)
|
||||
replace_parameter(
|
||||
layer, "w2_weight", torch.stack(w2_weight_shuffled).contiguous()
|
||||
)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w13_weight_scale",
|
||||
torch.stack(w13_scale_shuffled).contiguous(),
|
||||
)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w2_weight_scale",
|
||||
torch.stack(w2_scale_shuffled).contiguous(),
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
self._check_weight_dtypes(layer)
|
||||
self._shuffle_weights_for_trtllm(layer)
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
# TRTLLM MXFP8 path is monolithic and does not use modular kernel config.
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from flashinfer.fused_moe.core import (
|
||||
ActivationType,
|
||||
Fp8QuantizationType,
|
||||
)
|
||||
|
||||
assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB is not supported for FlashInfer TRTLLM MXFP8 MoE backend."
|
||||
)
|
||||
|
||||
supported_activations = [MoEActivation.SILU]
|
||||
if layer.activation not in supported_activations:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer TRTLLM MXFP8 MoE supports only "
|
||||
f"{supported_activations}, got {layer.activation}."
|
||||
)
|
||||
|
||||
# Map vLLM MoEActivation to FlashInfer ActivationType.
|
||||
activation_map = {
|
||||
MoEActivation.SILU: ActivationType.Swiglu,
|
||||
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
|
||||
}
|
||||
fi_activation_type: ActivationType = activation_map[layer.activation]
|
||||
|
||||
# DeepSeekV3 routing requires float32 logits; others expect bfloat16.
|
||||
if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
assert router_logits.dtype == torch.float32, (
|
||||
"DeepSeekV3 routing requires float32 router_logits, "
|
||||
f"got {router_logits.dtype}."
|
||||
)
|
||||
else:
|
||||
router_logits = router_logits.to(torch.bfloat16)
|
||||
|
||||
# Treat 0 as "unset" for compatibility with ungrouped routing configs.
|
||||
n_group = layer.num_expert_group or None
|
||||
topk_group = layer.topk_group or None
|
||||
|
||||
hidden_states_mxfp8, hidden_states_scale = mxfp8_e4m3_quantize(
|
||||
x,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
|
||||
kwargs: dict = dict(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
hidden_states=hidden_states_mxfp8,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
gemm1_weights=layer.w13_weight,
|
||||
gemm1_weights_scale=layer.w13_weight_scale,
|
||||
gemm2_weights=layer.w2_weight,
|
||||
gemm2_weights_scale=layer.w2_weight_scale,
|
||||
num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
# Keep Optional semantics: FlashInfer expects None for non-grouped
|
||||
# routing (e.g. Qwen3 Renormalize), not 0.
|
||||
n_group=n_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
use_shuffled_weight=True,
|
||||
weight_layout=0,
|
||||
fp8_quantization_type=Fp8QuantizationType.MxFp8,
|
||||
)
|
||||
|
||||
if fi_activation_type != ActivationType.Swiglu:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer TRTLLM MXFP8 MoE supports only Swiglu activation, "
|
||||
f"got {fi_activation_type}."
|
||||
)
|
||||
|
||||
return flashinfer_trtllm_fp8_block_scale_moe(**kwargs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
raise NotImplementedError(
|
||||
"Non-monolithic MXFP8 MoE path is not yet implemented."
|
||||
)
|
||||
|
||||
|
||||
# Register the method classes for ModelOptMxFp8Config
|
||||
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
|
||||
ModelOptMxFp8Config.FusedMoEMethodCls = ModelOptMxFp8FusedMoE
|
||||
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user