Deepseek-v3 Batch Invariant on 8xH100 (#26609)

Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Bram Wasti
2025-10-15 22:06:02 -07:00
committed by GitHub
parent 785d8b6410
commit 7d8975de84
21 changed files with 1567 additions and 102 deletions

View File

@@ -8,6 +8,10 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@@ -21,6 +25,8 @@ def rms_norm(
) -> torch.Tensor:
from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
out,
@@ -39,6 +45,10 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual
ops.fused_add_rms_norm(
x,
residual,