Deepseek-v3 Batch Invariant on 8xH100 (#26609)
Signed-off-by: Bram Wasti <bwasti@meta.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,10 @@ import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
rms_norm_batch_invariant,
|
||||
vllm_kernel_override_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@@ -21,6 +25,8 @@ def rms_norm(
|
||||
) -> torch.Tensor:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
return rms_norm_batch_invariant(x, weight, variance_epsilon)
|
||||
out = torch.empty_like(x)
|
||||
ops.rms_norm(
|
||||
out,
|
||||
@@ -39,6 +45,10 @@ def fused_add_rms_norm(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
return rms_norm_batch_invariant(
|
||||
x + residual, weight, variance_epsilon
|
||||
), x + residual
|
||||
ops.fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
|
||||
Reference in New Issue
Block a user