Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
95 lines
2.5 KiB
Python
95 lines
2.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Custom ops for prefetch offloader torch.compile + CUDA graph compatibility.
|
|
|
|
These ops use mutates_args to create data dependencies that prevent
|
|
the compiler from reordering prefetch/sync operations.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
from vllm.model_executor.offloader.base import get_offloader
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
|
|
# --- wait_prefetch op ---
|
|
|
|
|
|
def _wait_prefetch_impl(input_tensor: torch.Tensor, layer_idx: int) -> None:
    """Block until the prefetch for ``layer_idx`` has finished.

    Delegates to the active offloader's ``_wait_for_layer``, which is
    expected to synchronize the compute stream with the copy stream so the
    prefetched weights are ready before the layer runs.

    Args:
        input_tensor: Input to the layer (e.g., hidden_states). It is not
            read or written here; the op is registered with it declared as
            mutated purely to create a data dependency for torch.compile.
        layer_idx: Index of the layer whose prefetch must complete.
    """
    offloader = get_offloader()
    offloader._wait_for_layer(layer_idx)
|
|
|
|
|
|
def _wait_prefetch_fake(input_tensor: torch.Tensor, layer_idx: int) -> None:
    """No-op stand-in for ``wait_prefetch`` used during torch.compile tracing.

    Both arguments are accepted only to mirror the real op's signature;
    nothing is read, written, or synchronized.
    """
|
|
|
|
|
|
# --- start_prefetch op ---
|
|
|
|
|
|
def _start_prefetch_impl(output_tensor: torch.Tensor, layer_idx: int) -> None:
    """Kick off the asynchronous prefetch of the weights for ``layer_idx``.

    Delegates to the active offloader's ``_start_prefetch``, which is
    expected to initiate the H2D copy on the copy stream.

    Args:
        output_tensor: Output from forward. It is not read or written here;
            the op is registered with it declared as mutated so that
            torch.compile cannot move this op ahead of the computation
            that produces ``output_tensor``.
        layer_idx: Index of the layer to prefetch.
    """
    offloader = get_offloader()
    offloader._start_prefetch(layer_idx)
|
|
|
|
|
|
def _start_prefetch_fake(output_tensor: torch.Tensor, layer_idx: int) -> None:
    """No-op stand-in for ``start_prefetch`` used during torch.compile tracing.

    Both arguments are accepted only to mirror the real op's signature;
    no copy is started and nothing is modified.
    """
|
|
|
|
|
|
def register_prefetch_offloader_ops() -> None:
    """Register the ``wait_prefetch`` and ``start_prefetch`` custom ops.

    Each op is registered with its real implementation, its fake
    (tracing-time) implementation, and the single tensor argument declared
    as mutated. Registration has to happen before either op is used; this
    is typically done at module import time.
    """
    # (op name, real implementation, fake implementation, mutated argument)
    op_table = (
        (
            "wait_prefetch",
            _wait_prefetch_impl,
            _wait_prefetch_fake,
            "input_tensor",
        ),
        (
            "start_prefetch",
            _start_prefetch_impl,
            _start_prefetch_fake,
            "output_tensor",
        ),
    )
    for name, impl, fake, mutated_arg in op_table:
        direct_register_custom_op(
            op_name=name,
            op_func=impl,
            mutates_args=[mutated_arg],
            fake_impl=fake,
        )
|
|
|
|
|
|
# Register ops at module import time
|
|
# Deliberate import side effect: importing this module runs the registration
# once, so the ops exist before any code tries to use them.
# NOTE(review): presumably callers must not invoke the function again
# themselves (duplicate registration) - confirm against
# direct_register_custom_op.
register_prefetch_offloader_ops()
|