Disable Cascade Attention for Batch Invariance (#32561)
Signed-off-by: frankwang28 <frank.wbb@hotmail.com> Signed-off-by: Frank Wang <41319051+frankwang28@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -18,11 +18,18 @@ from vllm.distributed import (
 )
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import PluggableLayer
+from vllm.model_executor.layers.batch_invariant import (
+    linear_batch_invariant,
+    vllm_is_batch_invariant,
+)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
+from vllm.model_executor.layers.utils import (
+    dispatch_unquantized_gemm,
+    is_layer_moe_router_gate,
+)
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -236,6 +243,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        if (
+            vllm_is_batch_invariant()
+            and current_platform.is_cuda_alike()
+            and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
+        ):
+            return linear_batch_invariant(x, layer.weight, bias)
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
 
 
||||
Reference in New Issue
Block a user