[Platform] Custom ops support for FusedMoe (#22509)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -159,7 +160,8 @@ def get_masked_input_and_mask(
     return input_, ~vocab_mask
 
 
-class VocabParallelEmbedding(torch.nn.Module):
+@CustomOp.register("vocab_parallel_embedding")
+class VocabParallelEmbedding(CustomOp):
     """Embedding parallelized in the vocabulary dimension.
 
     Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
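Switching the base class from torch.nn.Module to CustomOp puts VocabParallelEmbedding behind vLLM's per-platform dispatch, so out-of-tree platforms can substitute their own forward implementation by the registered name. The sketch below is a toy illustration of that pattern only, assuming a simplified CustomOp with just native and CUDA backends; vLLM's real class lives in vllm/model_executor/custom_op.py and additionally handles op enable/disable via config and more backends.

import torch


class CustomOp(torch.nn.Module):
    """Toy dispatcher: routes forward() to a per-platform implementation.

    Illustration only; not vLLM's actual CustomOp.
    """

    _registry: dict[str, type["CustomOp"]] = {}

    @classmethod
    def register(cls, name: str):
        # Decorator used as @CustomOp.register("vocab_parallel_embedding");
        # it records the subclass so a platform plugin can look the op up
        # by name and swap in its own implementation.
        def decorator(op_cls: type["CustomOp"]) -> type["CustomOp"]:
            cls._registry[name] = op_cls
            return op_cls

        return decorator

    def forward(self, *args, **kwargs):
        # Pick the backend-specific forward; fall back to the native one.
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default: reuse the native implementation.
        return self.forward_native(*args, **kwargs)


@CustomOp.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
    """Minimal embedding op participating in the toy dispatch."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn(num_embeddings, embedding_dim))

    def forward_native(self, input_: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.embedding(input_, self.weight)


emb = VocabParallelEmbedding(num_embeddings=8, embedding_dim=4)
out = emb(torch.tensor([0, 3, 7]))
print(out.shape)  # torch.Size([3, 4])

In this scheme the decorator itself does not change the class; registration plus the CustomOp base is what lets forward() be rerouted per platform without touching call sites.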