diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index ee75d627d..851546297 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +import inspect + import torch import torch.nn as nn @@ -205,9 +208,9 @@ class CustomOp(nn.Module): NOTE: this does not enable fusion across ops, so opaque custom ops should still be unwrapped wherever possible. """ - # Do not compile if compilation disabled from vllm.config.compilation import CompilationMode + # Do not compile if compilation disabled if not enable: return fn @@ -220,14 +223,42 @@ class CustomOp(nn.Module): if compilation_config.backend == "eager": return fn + compile_options = maybe_disable_graph_partition( + current_platform.simple_compile_backend + ) + backend = current_platform.simple_compile_backend + + dynamic_arg_dims = getattr(self.__class__, "_dynamic_arg_dims", None) + if dynamic_arg_dims is not None: + compiled_fn = torch.compile( + fn, + dynamic=False, + backend=backend, + options=compile_options, + ) + sig = inspect.signature(fn) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + for name, dims in dynamic_arg_dims.items(): + arg = bound.arguments.get(name) + if arg is not None and isinstance(arg, torch.Tensor): + dims_list = [dims] if isinstance(dims, int) else dims + for d in dims_list: + real_d = arg.ndim + d if d < 0 else d + torch._dynamo.mark_dynamic(arg, real_d) + return compiled_fn(*args, **kwargs) + + return wrapper + # dynamic=True to avoid recompilations return torch.compile( fn, dynamic=True, - backend=current_platform.simple_compile_backend, - options=maybe_disable_graph_partition( - current_platform.simple_compile_backend - ), + backend=backend, + options=compile_options, ) @classmethod @@ -267,10 +298,15 @@ class CustomOp(nn.Module): # Decorator to register custom ops. @classmethod - def register(cls, name: str): + def register( + cls, + name: str, + dynamic_arg_dims: dict[str, int | list[int]] | None = None, + ): def decorator(op_cls): assert name not in op_registry, f"Duplicate op name: {name}" op_cls.name = name + op_cls._dynamic_arg_dims = dynamic_arg_dims op_registry[name] = op_cls return op_cls diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 98ff02e9d..faebad596 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -950,7 +950,10 @@ def dynamic_per_batched_tensor_quant( logger = init_logger(__name__) -@CustomOp.register("mla_decode_concat_quant_fp8") +@CustomOp.register( + "mla_decode_concat_quant_fp8", + dynamic_arg_dims={"decode_ql_nope": 0, "decode_q_pe": 0}, +) class _DecodeConcatQuantFP8(QuantFP8): """ QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before