[Platform] Custom ops support for LMhead and LogitsProcessor (#23564)

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
This commit is contained in:
zzhxxx
2025-09-10 21:26:31 +08:00
committed by GitHub
parent 2eb9986a2d
commit 736569da8d
2 changed files with 4 additions and 2 deletions

View File

@@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import torch
-import torch.nn as nn
import vllm.envs as envs
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
+from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
envs.VLLM_LOGITS_PROCESSOR_THREADS)
-class LogitsProcessor(nn.Module):
+@CustomOp.register("logits_processor")
+class LogitsProcessor(CustomOp):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following: